Example 1
def init(self, x, aux=None, forward=False):
    # For other norms, we pass in the BoundedTensor objects directly.
    x_L = x
    x_U = x
    if not forward:
        return LinearBound(None, None, None, None, x_L, x_U), x, None
    batch_size = x.shape[0]
    dim = x.reshape(batch_size, -1).shape[-1]
    # Identity weights: one row per flattened input coordinate.
    eye = torch.eye(dim, device=x.device).unsqueeze(0).repeat(batch_size, 1, 1)
    lw = eye.reshape(batch_size, dim, *x.shape[1:])
    lb = torch.zeros_like(x)
    uw, ub = lw.clone(), lb.clone()
    return LinearBound(lw, lb, uw, ub, x_L, x_U), x, None
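A minimal standalone sketch (illustrative shapes, not taken from the example above) of how the forward-mode identity weights in init are laid out: each flattened input coordinate gets one row of an identity matrix, reshaped back to the input's original shape.

import torch

# Hedged sketch with assumed sizes: a batch of 2 inputs of shape (3, 4), so dim = 12.
batch_size, shape = 2, (3, 4)
x = torch.randn(batch_size, *shape)
dim = x.reshape(batch_size, -1).shape[-1]
eye = torch.eye(dim, device=x.device).unsqueeze(0).repeat(batch_size, 1, 1)
lw = eye.reshape(batch_size, dim, *x.shape[1:])
assert lw.shape == (batch_size, dim, *shape)  # one identity row per input coordinate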
Example 2
def _init_forward(self, root, dim_in):
    if dim_in == 0:
        raise ValueError(
            "At least one node should have a specified perturbation")
    prev_dim_in = 0
    batch_size = root[0].value.shape[0]
    for i in range(len(root)):
        if root[i].perturbation is not None:
            shape = root[i].linear.lw.shape
            device = root[i].linear.lw.device
            # Zero-pad the local lw/uw so their columns line up with the
            # global input dimension `dim_in`, starting at offset `prev_dim_in`.
            root[i].linear = root[i].linear._replace(
                lw=torch.cat([
                    torch.zeros(shape[0], prev_dim_in, *shape[2:], device=device),
                    root[i].linear.lw,
                    torch.zeros(shape[0], dim_in - shape[1] - prev_dim_in,
                                *shape[2:], device=device)
                ], dim=1),
                uw=torch.cat([
                    torch.zeros(shape[0], prev_dim_in, *shape[2:], device=device),
                    root[i].linear.uw,
                    torch.zeros(shape[0], dim_in - shape[1] - prev_dim_in,
                                *shape[2:], device=device)
                ], dim=1))
            if i >= self.num_global_inputs:
                root[i].forward_value = root[i].forward_value.unsqueeze(0).repeat(
                    *([batch_size] + [1] * len(root[i].forward_value.shape)))
            prev_dim_in += shape[1]
        else:
            fv = root[i].forward_value
            shape = fv.shape
            if root[i].from_input:
                w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device)
            else:
                w = None
            b = fv
            root[i].linear = LinearBound(w, b, w, b, b, b)
            root[i].lower = root[i].upper = b
            root[i].interval = (root[i].lower, root[i].upper)
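A minimal standalone sketch of the zero-padding step in `_init_forward`, with illustrative sizes (not taken from the code above): a node whose local `lw` covers a few input columns is embedded into the global `dim_in` columns at offset `prev_dim_in`.

import torch

# Hedged sketch: local weights covering 3 of the 10 global input columns,
# placed at offset 4 by padding with zeros on both sides.
batch_size, local_dim, feat = 2, 3, 5
dim_in, prev_dim_in = 10, 4
lw_local = torch.randn(batch_size, local_dim, feat)
lw_global = torch.cat([
    torch.zeros(batch_size, prev_dim_in, feat),
    lw_local,
    torch.zeros(batch_size, dim_in - local_dim - prev_dim_in, feat),
], dim=1)
assert lw_global.shape == (batch_size, dim_in, feat)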
Example 3
    def init(self, x, aux=None, forward=False):
        if forward:
            raise NotImplementedError()

        x_np = x.cpu().numpy()
        original_shape = x_np.shape
        x_np = np.reshape(x_np, (-1, original_shape[-1]))
        interval_bounds = self.transformation.transform(x_np, self.params)
        lb = np.reshape(interval_bounds.lower_bound, original_shape)
        ub = np.reshape(interval_bounds.upper_bound, original_shape)
        lb = torch.tensor(lb, device=x.device)
        ub = torch.tensor(ub, device=x.device)
        assert x.size() == lb.size() and x.size() == ub.size(), \
            f"bounds must have the same shape as x. Got x:{x.size()}, lb:{lb.size()}, ub:{ub.size()}"
        return LinearBound(None, None, None, None, lb, ub), x, None
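A minimal standalone sketch of the numpy round-trip in this `init`: flatten all but the last axis, compute elementwise interval bounds, and restore the original shape. The +/- 0.1 box below is a stand-in assumption for `self.transformation.transform`, which is not shown here.

import numpy as np
import torch

# Hedged sketch: a toy +/- 0.1 box replaces the real transformation.
x = torch.randn(2, 3, 4)
x_np = x.cpu().numpy()
original_shape = x_np.shape
flat = np.reshape(x_np, (-1, original_shape[-1]))
lb = torch.tensor(np.reshape(flat - 0.1, original_shape), device=x.device)
ub = torch.tensor(np.reshape(flat + 0.1, original_shape), device=x.device)
assert x.size() == lb.size() == ub.size()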
Example 4
def test():
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)

    dim_input = 11
    
    # multiplication of [batch_size, dim_input] and [dim_output, dim_input]^T
    weight = torch.randn(dim_output, dim_input, device=device)
    bias = torch.randn(dim_output, device=device)
    data_in = torch.randn(batch_size, dim_input, device=device)
    data_in_delta = torch.randn(batch_size, dim_input, device=device)
    dummy_in = Dummy(data_in - torch.abs(data_in_delta), data_in + torch.abs(data_in_delta), True)
    dummy_weight = Dummy(weight)
    dummy_bias = Dummy(bias)

    op = BoundLinear(
        input_name=[None, None, None], 
        name=None, ori_name=None, attr=None, 
        inputs=[dummy_in, dummy_weight, dummy_bias],
        output_index=0, options={}, device=device)

    # test `forward`
    data_out = op(data_in, weight, bias)
    assert equal(data_out, data_in.matmul(weight.t()) + bias)

    # test `bound_backward`
    # The `transpose` here keeps the randomization consistent with the previous reference;
    # it can be removed once a new reference is generated.
    last_lA = torch.randn(batch_size, dim_final, dim_output, device=device).transpose(0, 1)
    last_uA = torch.randn(batch_size, dim_final, dim_output, device=device).transpose(0, 1)
    A, lbias, ubias = op.bound_backward(last_lA, last_uA, *op.inputs)
    assert equal(A[0][0], last_lA.matmul(weight))
    assert equal(A[0][1], last_uA.matmul(weight))
    assert equal(lbias, last_lA.matmul(bias))
    assert equal(ubias, last_uA.matmul(bias))

    # test `bound_forward`
    # note that the upper bound may actually be smaller than the lower bound
    # in these dummy linear bounds
    bound_in = LinearBound(
        lw=torch.randn(batch_size, dim_final, dim_input, device=device),
        lb=torch.randn(batch_size, dim_input, device=device),
        uw=torch.randn(batch_size, dim_final, dim_input, device=device),
        ub=torch.randn(batch_size, dim_input, device=device),
        lower=None, upper=None)
    bound_weight = LinearBound(None, None, None, None, dummy_weight.lower, dummy_weight.upper)
    bound_bias = LinearBound(None, None, None, None, dummy_bias.lower, dummy_bias.upper)
    bound_out = op.bound_forward(dim_final, bound_in, bound_weight, bound_bias)
    assert equal(bound_out.lw, 
        bound_in.lw.matmul(weight.t().clamp(min=0)) + bound_in.uw.matmul(weight.t().clamp(max=0)))
    assert equal(bound_out.uw, 
        bound_in.uw.matmul(weight.t().clamp(min=0)) + bound_in.lw.matmul(weight.t().clamp(max=0)))
    assert equal(bound_out.lb, 
        bound_in.lb.matmul(weight.t().clamp(min=0)) + bound_in.ub.matmul(weight.t().clamp(max=0)) + bias)
    assert equal(bound_out.ub, 
        bound_in.ub.matmul(weight.t().clamp(min=0)) + bound_in.lb.matmul(weight.t().clamp(max=0)) + bias)

    # test `interval_propagate`
    bound_in = (
        torch.randn(*data_in.shape, device=device), 
        torch.randn(*data_in.shape, device=device))
    bound_weight = (bound_weight.lower, bound_weight.upper)
    bound_bias = (bound_bias.lower, bound_bias.upper)
    bound_out = op.interval_propagate(bound_in, bound_weight, bound_bias)
    assert equal(bound_out[0], 
        bound_in[0].matmul(weight.t().clamp(min=0)) + bound_in[1].matmul(weight.t().clamp(max=0)) + bias)
    assert equal(bound_out[1], 
        bound_in[1].matmul(weight.t().clamp(min=0)) + bound_in[0].matmul(weight.t().clamp(max=0)) + bias)

    # test weight perturbation
    # `bound_backward`
    ptb_weight = torch.randn(weight.shape, device=device)
    op.inputs[1].upper += ptb_weight
    op.inputs[1].perturbed = True
    op.inputs[2].perturbation = None # no perturbation on bias
    A, lbias, ubias = op.bound_backward(last_lA, last_uA, *op.inputs)
    # `interval_propagate`
    bound_weight = (op.inputs[1].lower, op.inputs[1].upper)
    bound_out = op.interval_propagate(bound_in, bound_weight, bound_bias)
    if args.gen_ref:
        if not os.path.exists('data/bound_ops'):
            os.mkdir('data/bound_ops')
        with open('data/bound_ops/weight_ptb.pkl', 'wb') as file:
            pickle.dump((A, lbias, ubias, bound_out), file)
    with open('data/bound_ops/weight_ptb.pkl', 'rb') as file:
        A_ref, lbias_ref, ubias_ref, bound_out_ref = pickle.load(file)
    for i in range(3):
        for j in range(2):
            if A_ref[i][j] is not None:
                ref = A_ref[i][j]
                # legacy reference
                if ref.shape[0] == batch_size:
                    ref = ref.transpose(0, 1)
                assert equal(A[i][j], ref) 
    lbias, ubias = lbias.transpose(0, 1), ubias.transpose(0, 1)
    assert equal(lbias, lbias_ref)
    assert equal(ubias, ubias_ref)
    assert equal(bound_out[0], bound_out_ref[0]) and equal(bound_out[1], bound_out_ref[1])
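For reference, a small standalone check of the interval rule that the `interval_propagate` assertions above rely on: for y = x W^T + b, the lower bound pairs the input lower bound with the positive part of W and the input upper bound with the negative part, and symmetrically for the upper bound. Sizes and seed below are illustrative.

import torch

# Hedged numeric check: any x in [l, u] maps into [lo, hi] under y = x W^T + b.
torch.manual_seed(0)
W, b = torch.randn(5, 3), torch.randn(5)
l = torch.randn(4, 3)
u = l + torch.rand(4, 3)
lo = l.matmul(W.t().clamp(min=0)) + u.matmul(W.t().clamp(max=0)) + b
hi = u.matmul(W.t().clamp(min=0)) + l.matmul(W.t().clamp(max=0)) + b
x = l + (u - l) * torch.rand(4, 3)
y = x.matmul(W.t()) + b
assert (y >= lo - 1e-5).all() and (y <= hi + 1e-5).all()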
Example 5
    def init(self, x, aux=None, forward=False):
        tokens, batch = aux
        self.tokens = tokens # DEBUG
        assert len(x.shape) == 3
        batch_size, length, dim_word = x.shape

        max_pos = 1
        can_be_replaced = np.zeros((batch_size, length), dtype=bool)

        self._build_substitution(batch)

        for t in range(batch_size):
            cnt = 0
            candidates = batch[t]['candidates']
            # for transformers
            if tokens[t][0] == '[CLS]':
                candidates = [[]] + candidates + [[]]
            for i in range(len(tokens[t])):
                if tokens[t][i] == '[UNK]' or \
                        len(candidates[i]) == 0 or tokens[t][i] != candidates[i][0]:
                    continue
                for w in candidates[i][1:]:
                    if w in self.model.vocab:
                        can_be_replaced[t][i] = True
                        cnt += 1
                        break
            max_pos = max(max_pos, cnt)

        dim = max_pos * dim_word
        if forward:
            eye = torch.eye(dim_word, device=x.device)
            lw = torch.zeros(batch_size, dim, length, dim_word, device=x.device)
            lb = torch.zeros_like(x)
        x_new = []
        word_embeddings = self.model.word_embeddings.weight
        vocab = self.model.vocab
        x_rep = [[[] for i in range(length)] for t in range(batch_size)]
        max_num_cand = 1
        for t in range(batch_size):
            candidates = batch[t]['candidates']
            # for transformers
            if tokens[t][0] == '[CLS]':
                candidates = [[]] + candidates + [[]]  
            cnt = 0    
            for i in range(length):
                if can_be_replaced[t][i]:
                    word_embed = word_embeddings[vocab[tokens[t][i]]]
                    if forward:
                        lw[t, (cnt * dim_word):((cnt + 1) * dim_word), i, :] = eye
                        lb[t, i, :] = x[t, i, :] - word_embed
                    for w in candidates[i][1:]:
                        if w in self.model.vocab:
                            x_rep[t][i].append(
                                word_embeddings[self.model.vocab[w]])
                    max_num_cand = max(max_num_cand, len(x_rep[t][i]))
                    cnt += 1
                else:
                    if forward:
                        lb[t, i, :] = x[t, i, :]
        if forward:
            uw, ub = lw, lb
        else:
            lw = lb = uw = ub = None
        zeros = torch.zeros(dim_word, device=x.device)
        
        x_rep_, mask = [], []
        for t in range(batch_size):
            for i in range(length):
                x_rep_ += x_rep[t][i] + [zeros] * (max_num_cand - len(x_rep[t][i]))
                mask += [1] * len(x_rep[t][i]) + [0] * (max_num_cand - len(x_rep[t][i]))
        x_rep_ = torch.cat(x_rep_).reshape(batch_size, length, max_num_cand, dim_word)
        mask = torch.tensor(mask, dtype=torch.float32, device=x.device)\
            .reshape(batch_size, length, max_num_cand)
        x_rep_ = x_rep_ * self.eps + x.unsqueeze(2) * (1 - self.eps)
        
        inf = 1e20
        lower = torch.min(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * inf, dim=2).values
        upper = torch.max(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * (-inf), dim=2).values
        lower = torch.min(lower, x)
        upper = torch.max(upper, x)

        return LinearBound(lw, lb, uw, ub, lower, upper), x, (x_rep_, mask, can_be_replaced)
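A minimal standalone sketch of the masked reduction at the end of this `init`: padded candidate slots are pushed to +/- inf through the mask so they can never win the min/max, and the result is then clipped against the original embeddings. Sizes and data below are illustrative.

import torch

# Hedged sketch: random embeddings and a random validity mask over candidate slots.
batch_size, length, max_num_cand, dim_word = 2, 3, 4, 5
x = torch.randn(batch_size, length, dim_word)
x_rep = torch.randn(batch_size, length, max_num_cand, dim_word)
mask = (torch.rand(batch_size, length, max_num_cand) > 0.5).float()
inf = 1e20
lower = torch.min(mask.unsqueeze(-1) * x_rep + (1 - mask).unsqueeze(-1) * inf, dim=2).values
upper = torch.max(mask.unsqueeze(-1) * x_rep + (1 - mask).unsqueeze(-1) * (-inf), dim=2).values
lower = torch.min(lower, x)   # the original point is always feasible
upper = torch.max(upper, x)
assert (lower <= upper).all()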
Example 6
    def _forward_general(self,
                         C=None,
                         node=None,
                         root=None,
                         dim_in=None,
                         concretize=False):
        if hasattr(node, 'lower'):
            return node.lower, node.upper

        if not node.from_input:
            w = None
            b = node.forward_value
            node.linear = LinearBound(w, b, w, b, b, b)
            node.lower = node.upper = b
            node.interval = (node.lower, node.upper)
            return node.interval

        if not hasattr(node, 'linear'):
            for l_pre in node.input_name:
                l = self._modules[l_pre]
                if not hasattr(l, 'linear'):
                    self._forward_general(node=l, root=root, dim_in=dim_in)

            inp = [self._modules[l_pre].linear for l_pre in node.input_name]
            node.linear = node.bound_forward(dim_in, *inp)

        lw, uw = node.linear.lw, node.linear.uw
        lower, upper = node.linear.lb, node.linear.ub
        if C is not None:
            C_pos, C_neg = C.clamp(min=0), C.clamp(max=0)
            _lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(
                uw, C_neg.transpose(-1, -2))
            _uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(
                lw, C_neg.transpose(-1, -2))
            lw, uw = _lw, _uw
            _lower = torch.matmul(lower.unsqueeze(1), C_pos.transpose(-1, -2)) + \
                torch.matmul(upper.unsqueeze(1), C_neg.transpose(-1, -2))
            _upper = torch.matmul(upper.unsqueeze(1), C_pos.transpose(-1, -2)) + \
                torch.matmul(lower.unsqueeze(1), C_neg.transpose(-1, -2))
            lower, upper = _lower.squeeze(-1), _upper.squeeze(-1)

        if concretize:
            if node.linear.lw is not None:
                prev_dim_in = 0
                batch_size = lw.shape[0]
                assert len(lw.shape) > 1
                lA = lw.reshape(batch_size, dim_in, -1).transpose(1, 2)
                uA = uw.reshape(batch_size, dim_in, -1).transpose(1, 2)
                for i in range(len(root)):
                    if root[i].perturbation is not None:
                        _lA = lA[:, :, prev_dim_in:(prev_dim_in + root[i].dim)]
                        _uA = uA[:, :, prev_dim_in:(prev_dim_in + root[i].dim)]
                        lower = lower + root[i].perturbation.concretize(
                            root[i].center, _lA, sign=-1,
                            aux=root[i].aux).view(lower.shape)
                        upper = upper + root[i].perturbation.concretize(
                            root[i].center, _uA, sign=+1,
                            aux=root[i].aux).view(upper.shape)
                        prev_dim_in += root[i].dim
                if C is None:
                    node.linear = node.linear._replace(lower=lower,
                                                       upper=upper)
            if C is None:
                node.lower, node.upper = lower, upper
            return lower, upper
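A minimal standalone sketch of the C-matrix step in `_forward_general`: positive entries of C keep the roles of `lw` and `uw`, while negative entries swap them, so the transformed coefficients still satisfy lw <= uw whenever the inputs do. Shapes below are illustrative.

import torch

# Hedged sketch: lw <= uw elementwise, C is an arbitrary specification matrix.
batch_size, dim_in, out_dim, spec_dim = 2, 4, 6, 3
lw = torch.randn(batch_size, dim_in, out_dim)
uw = lw + torch.rand(batch_size, dim_in, out_dim)
C = torch.randn(batch_size, spec_dim, out_dim)
C_pos, C_neg = C.clamp(min=0), C.clamp(max=0)
new_lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(uw, C_neg.transpose(-1, -2))
new_uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(lw, C_neg.transpose(-1, -2))
assert new_lw.shape == new_uw.shape == (batch_size, dim_in, spec_dim)
assert (new_lw <= new_uw + 1e-5).all()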