Example #1
    def concretize_2bounds(self, x, Ax, sum_b, sign=-1, y=[]):
        # only linear layers are supported so far
        if Ax is None:
            return None
        batch = x.shape[0]
        _tmp_Ay = 0
        _tmp_center = 0
        if sign == -1:
            for i in range(len(y)):
                logger.debug(y[i].shape)
                logger.debug(y[i].lA_y.shape)
                Ay = y[i].lA_y
                Ay = Ay.reshape(*Ay.shape[:2], -1)
                _tmp_Ay -= torch.norm(Ay, self.dual_norm, -1) * y[i].eps
                _tmp_center += Ay.bmm(
                    y[i].reshape(-1).unsqueeze(-1).unsqueeze(0).repeat(
                        batch, 1, 1))
        elif sign == 1:
            for i in range(len(y)):
                Ay = y[i].uA_y
                Ay = Ay.reshape(*Ay.shape[:2], -1)
                _tmp_Ay += torch.norm(Ay, self.dual_norm, -1) * y[i].eps
                _tmp_center += Ay.bmm(
                    y[i].reshape(-1).unsqueeze(-1).unsqueeze(0).repeat(
                        batch, 1, 1))

        _tmp_center += Ax.bmm(x.reshape(
            batch, -1).unsqueeze(-1)) + sum_b.unsqueeze(-1)
        bound = _tmp_center.squeeze(-1) + sign * torch.norm(
            Ax, self.dual_norm, -1) * self.eps + _tmp_Ay

        return bound
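
Example #1 applies Hölder's inequality: over an l_p ball of radius eps around a center point, the extreme value of a linear form A·x + b is A·x0 + b ± ||A||_q·eps, where q is the dual norm of p. Below is a minimal, self-contained sketch of that concretization step; the shapes and the helper name concretize_linear are illustrative assumptions, not the library's API.

import torch

def concretize_linear(x0, A, b, eps, p, sign):
    # x0: (batch, n) center, A: (batch, m, n), b: (batch, m); assumes p > 1 or p = inf
    q = 1.0 if p == float('inf') else p / (p - 1.0)     # dual norm of p
    center = A.bmm(x0.unsqueeze(-1)).squeeze(-1) + b    # A x0 + b
    deviation = torch.norm(A, p=q, dim=-1) * eps        # ||A||_q * eps, row-wise
    return center + sign * deviation                    # sign=-1: lower, sign=+1: upper

x0, A, b = torch.randn(2, 4), torch.randn(2, 3, 4), torch.zeros(2, 3)
lower = concretize_linear(x0, A, b, eps=0.1, p=float('inf'), sign=-1)
upper = concretize_linear(x0, A, b, eps=0.1, p=float('inf'), sign=+1)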
Example #2
    def _convert(self, model, global_input):
        if self.verbose:
            logger.info('Converting the model...')

        if not isinstance(global_input, tuple):
            global_input = (global_input, )
        self.num_global_inputs = len(global_input)

        nodesOP, nodesIO = self._convert_nodes(model, global_input)
        global_input = tuple([i.to(self.device) for i in global_input])

        while True:
            self._build_graph(nodesOP, nodesIO)
            self.forward(*global_input)
            nodesOP, nodesIO, found_complex = self._split_complex(
                nodesOP, nodesIO)
            if not found_complex: break

        for node in self.nodes:
            for p in list(node.named_parameters()):
                if node.ori_name not in self._parameters:
                    # For parameter or input nodes, use their original name directly
                    self._parameters[node.ori_name] = p[1]

        logger.debug('NodesOP:')
        for node in nodesOP:
            logger.debug('{}'.format(node._replace(param=None)))
        logger.debug('NodesIO:')
        for node in nodesIO:
            logger.debug('{}'.format(node._replace(param=None)))

        if self.verbose:
            logger.info('Model converted to support bounds')
Example #3
    def _convert(self, model, global_input):
        if self.verbose:
            logger.info('Converting the model...')

        if not isinstance(global_input, tuple):
            global_input = (global_input, )
        self.num_global_inputs = len(global_input)

        nodesOP, nodesIO = self._convert_nodes(model, global_input)
        global_input = tuple([i.to(self.device) for i in global_input])

        while True:
            self._build_graph(nodesOP, nodesIO)
            self.forward(*global_input)  # running means/vars changed
            nodesOP, nodesIO, found_complex = self._split_complex(
                nodesOP, nodesIO)
            if not found_complex: break

        self._get_node_name_map()

        # reload self.ori_state_dict to undo the running mean/var updates caused by forward()
        self.load_state_dict(self.ori_state_dict)
        model.load_state_dict(self.ori_state_dict)
        delattr(self, 'ori_state_dict')

        logger.debug('NodesOP:')
        for node in nodesOP:
            logger.debug('{}'.format(node._replace(param=None)))
        logger.debug('NodesIO:')
        for node in nodesIO:
            logger.debug('{}'.format(node._replace(param=None)))

        if self.verbose:
            logger.info('Model converted to support bounds')
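
The state_dict reload in Example #3 exists because the trial forward() calls made while splitting complex nodes run in training mode and therefore update BatchNorm running statistics. The following minimal sketch shows the save-and-restore pattern independently of the library; the module and variable names are illustrative.

import copy
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
saved = copy.deepcopy(bn.state_dict())   # analogous to self.ori_state_dict

bn.train()
bn(torch.randn(8, 4))                    # running_mean / running_var are updated here

bn.load_state_dict(saved)                # restore the original statistics
assert torch.allclose(bn.running_mean, saved['running_mean'])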
Example #4
    def _convert(self, model, global_input):
        if self.verbose:
            logger.info('Converting the model...')

        if not isinstance(global_input, tuple):
            global_input = (global_input, )
        self.num_global_inputs = len(global_input)
        self.device = global_input[0].device

        nodesOP, nodesIO = self._convert_nodes(model, global_input)

        while True:
            self._build_graph(nodesOP, nodesIO)
            self.forward(*global_input)
            nodesOP, nodesIO, found_complex = self._split_complex(
                nodesOP, nodesIO)
            if not found_complex: break

        for node in self.nodes:
            for p in list(node.named_parameters()):
                self.register_parameter('{}/{}'.format(node.name, p[0]), p[1])

        logger.debug('NodesOP:')
        for node in nodesOP:
            logger.debug('{}'.format(node._replace(param=None)))
        logger.debug('NodesIO:')
        for node in nodesIO:
            logger.debug('{}'.format(node._replace(param=None)))

        if self.verbose:
            logger.info('Model converted to support bounds')
Example #5
    def _convert_nodes(self, model, global_input):
        global_input_cpu = tuple([i.to("cpu") for i in list(global_input)])
        model.train()
        model.to('cpu')
        nodesOP, nodesIO = get_graph_params(model, global_input_cpu)
        model.to(self.device)
        for i in range(0, len(nodesIO)):
            if nodesIO[i].param is not None:
                nodesIO[i] = nodesIO[i]._replace(
                    param=nodesIO[i].param.to(self.device))

        for n in range(len(nodesOP)):
            attr = nodesOP[n].attr
            inputs = self._get_node_input(nodesOP, nodesIO, nodesOP[n])

            if nodesOP[n].op in bound_op_map:
                if nodesOP[n].op == 'onnx::BatchNormalization':
                    # the BatchNormalization node needs the model.training flag to handle its running mean/vars
                    nodesOP[n] = nodesOP[n]._replace(
                        bound_node=bound_op_map[nodesOP[n].op]
                        (nodesOP[n].inputs, nodesOP[n].name, attr, inputs,
                         nodesOP[n].output_index, self.device, model.training))
                else:
                    nodesOP[n] = nodesOP[n]._replace(
                        bound_node=bound_op_map[nodesOP[n].op](
                            nodesOP[n].inputs, nodesOP[n].name, attr, inputs,
                            nodesOP[n].output_index, self.device))
            else:
                print(nodesOP[n])
                raise NotImplementedError('Unsupported operation {}'.format(
                    nodesOP[n].op))

            if self.verbose:
                logger.debug(
                    'Convert complete for {} with operation: {}'.format(
                        nodesOP[n].name, nodesOP[n].op))

        for i in range(0, len(global_input)):
            nodesIO[i] = nodesIO[i]._replace(param=global_input[i],
                                             bound_node=BoundInput(
                                                 nodesIO[i].inputs,
                                                 nodesIO[i].name,
                                                 value=global_input[i]))
            nodesIO[i].bound_node.method = 'forward'
        for i in range(len(global_input), len(nodesIO)):
            nodesIO[i] = nodesIO[i]._replace(bound_node=BoundParams(
                nodesIO[i].inputs, nodesIO[i].name, value=nodesIO[i].param))
            nodesIO[i].bound_node.method = 'forward'

        return nodesOP, nodesIO
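
Example #5 keeps rebuilding node records with _replace because the node entries are namedtuples and therefore immutable; _replace returns a fresh tuple with the chosen fields swapped, which is then stored back into the list. A short illustration of the pattern follows; the Node type below is a stand-in, not the library's actual record.

from collections import namedtuple

Node = namedtuple('Node', ['name', 'op', 'param', 'bound_node'])

nodes = [Node(name='linear0', op='onnx::Gemm', param=None, bound_node=None)]
# Fields cannot be assigned in place; build a new tuple and store it back.
nodes[0] = nodes[0]._replace(param='loaded weight tensor')
print(nodes[0].param)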
Example #6
def parse_module(module,
                 inputs,
                 param_exclude=".*AuxLogits.*",
                 param_include=None):
    params = _get_jit_params(module,
                             param_exclude=param_exclude,
                             param_include=param_include)
    if version.parse(torch.__version__) < version.parse("1.4.0"):
        trace, out = torch.jit.get_trace_graph(module, inputs)
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
        trace_graph = trace.graph()
    else:
        # _get_trace_graph becomes an internal function in version >= 1.4.0
        trace, out = torch.jit._get_trace_graph(module, inputs)
        # this is not present in older torch
        from torch.onnx.symbolic_helper import _set_opset_version
        if version.parse(torch.__version__) < version.parse("1.5.0"):
            _set_opset_version(11)
        else:
            _set_opset_version(12)
        trace_graph = torch.onnx._optimize_trace(
            trace, torch.onnx.OperatorExportTypes.ONNX)

    logger.debug('trace_graph: {}'.format(trace_graph))

    if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0:
        print("Graph before ONNX convertion:")
        print(trace)
        print("ONNX graph:")
        print(trace_graph)

    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    nodesOP, nodesIn, nodesOut = parse_graph(trace_graph, tuple(inputs),
                                             tuple(params))

    for i in range(len(nodesOP)):
        param_in = OrderedDict()
        for inp in nodesOP[i].inputs:
            for n in nodesIn:
                if inp == n.name:
                    param_in.update({inp: n.param})
        nodesOP[i] = nodesOP[i]._replace(param=param_in)

    template = get_output_template(out)

    return nodesOP, nodesIn, nodesOut, template
Example #7
    def _backward_general(self,
                          C=None,
                          node=None,
                          root=None,
                          bound_lower=True,
                          bound_upper=True):
        logger.debug('Backward from {} {}'.format(node.name, node))

        degree_out = {}
        for l in self.nodes:
            l.bounded = True
            l.lA = l.uA = None
            degree_out[l.name] = 0
        queue = [node]
        while len(queue) > 0:
            l = queue[0]
            queue = queue[1:]
            for l_pre in l.input_name:
                degree_out[l_pre] += 1
                if self.node_dict[l_pre].bounded:
                    self.node_dict[l_pre].bounded = False
                    queue.append(self.node_dict[l_pre])
        node.bounded = True
        node.lA = C if bound_lower else None
        node.uA = C if bound_upper else None
        lb = ub = torch.tensor(0.).to(C.device)

        queue = [node]
        while len(queue) > 0:
            l = queue[0]  # backward from l
            queue = queue[1:]
            l.bounded = True

            if l.name in self.root_name or l == root: continue

            for l_pre in l.input_name:
                _l = self.node_dict[l_pre]
                degree_out[l_pre] -= 1
                if degree_out[l_pre] == 0:
                    queue.append(_l)

            if l.lA is not None or l.uA is not None:

                def add_bound(node, lA, uA):
                    if lA is not None:
                        node.lA = lA if node.lA is None else (node.lA + lA)
                    if uA is not None:
                        node.uA = uA if node.uA is None else (node.uA + uA)

                input_nodes = [
                    self.node_dict[l_name] for l_name in l.input_name
                ]
                A, lower_b, upper_b = l.bound_backward(l.lA, l.uA,
                                                       *input_nodes)
                lb = lb + lower_b
                ub = ub + upper_b

                for i, l_pre in enumerate(l.input_name):
                    _l = self.node_dict[l_pre]
                    add_bound(_l, lA=A[i][0], uA=A[i][1])

        batch_size = C.shape[0]
        output_shape = node.forward_value.shape[1:]
        if node.forward_value.contiguous().view(batch_size,
                                                -1).shape[1] != C.shape[1]:
            output_shape = [-1]

        for i in range(len(root)):
            if root[i].lA is None and root[i].uA is None: continue
            logger.debug('concretize node: {} shape: {}'.format(
                root[i], root[i].lA.shape))
            lA = root[i].lA.reshape(batch_size, root[i].lA.shape[1],
                                    -1) if bound_lower else None
            uA = root[i].uA.reshape(batch_size, root[i].uA.shape[1],
                                    -1) if bound_upper else None
            if root[i].perturbation is not None:
                if isinstance(root[i], BoundParams):
                    # add batch_size dim for weights node
                    lb = lb + root[i].perturbation.concretize(
                        root[i].center.unsqueeze(0).repeat(
                            ([batch_size] + [1] * len(root[i].center.shape))),
                        lA,
                        sign=-1,
                        aux=root[i].aux) if bound_lower else None
                    ub = ub + root[i].perturbation.concretize(
                        root[i].center.unsqueeze(0).repeat(
                            ([batch_size] + [1] * len(root[i].center.shape))),
                        uA,
                        sign=+1,
                        aux=root[i].aux) if bound_upper else None
                else:
                    lb = lb + root[i].perturbation.concretize(
                        root[i].center, lA, sign=-1,
                        aux=root[i].aux) if bound_lower else None
                    ub = ub + root[i].perturbation.concretize(
                        root[i].center, uA, sign=+1,
                        aux=root[i].aux) if bound_upper else None
            elif i < self.num_global_inputs:
                lb = lb + root[i].lA.reshape(
                    batch_size, root[i].lA.shape[1], -1).bmm(
                        root[i].forward_value.view(batch_size, -1, 1)).squeeze(
                            -1) if bound_lower else None
                ub = ub + root[i].uA.reshape(
                    batch_size, root[i].uA.shape[1], -1).bmm(
                        root[i].forward_value.view(batch_size, -1, 1)).squeeze(
                            -1) if bound_upper else None
            else:
                lb = lb + root[i].lA.reshape(
                    batch_size,
                    root[i].lA.shape[1], -1).matmul(root[i].forward_value.view(
                        -1, 1)).squeeze(-1) if bound_lower else None
                ub = ub + root[i].uA.reshape(
                    batch_size,
                    root[i].uA.shape[1], -1).matmul(root[i].forward_value.view(
                        -1, 1)).squeeze(-1) if bound_upper else None

        node.lower = lb.view(batch_size, *
                             output_shape) if bound_lower else None
        node.upper = ub.view(batch_size, *
                             output_shape) if bound_upper else None
        return node.lower, node.upper
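
The two loops at the top of Example #7 compute a reverse topological order on the fly: the first pass counts, for every node reachable from the output, how many successors will send it an A matrix (degree_out); the second pass only enqueues a node once that count reaches zero, so its accumulated lA/uA are complete before they are propagated further. Here is a small self-contained sketch of the same scheme on a toy graph; the graph and node names are made up.

# Edges point from inputs to outputs; we walk backward from 'out'.
inputs_of = {'out': ['relu'], 'relu': ['lin1', 'lin2'],
             'lin1': ['x'], 'lin2': ['x'], 'x': []}

# Pass 1: out-degree of each node, restricted to the subgraph reachable from the output.
degree_out = {name: 0 for name in inputs_of}
queue, seen = ['out'], {'out'}
while queue:
    node = queue.pop(0)
    for pre in inputs_of[node]:
        degree_out[pre] += 1
        if pre not in seen:
            seen.add(pre)
            queue.append(pre)

# Pass 2: Kahn-style traversal; a node is ready once all its successors are done.
queue, order = ['out'], []
while queue:
    node = queue.pop(0)
    order.append(node)                 # the real code calls bound_backward here
    for pre in inputs_of[node]:
        degree_out[pre] -= 1
        if degree_out[pre] == 0:
            queue.append(pre)

print(order)                           # ['out', 'relu', 'lin1', 'lin2', 'x']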
Example #8
    def _backward_general(self,
                          C=None,
                          node=None,
                          root=None,
                          bound_lower=True,
                          bound_upper=True,
                          return_A=False,
                          average_A=False):
        _print_time = False

        degree_out = {}
        for l in self._modules.values():
            l.bounded = True
            l.lA = l.uA = None
            degree_out[l.name] = 0
        queue = [node]
        while len(queue) > 0:
            l = queue[0]
            queue = queue[1:]
            for l_pre in l.input_name:
                degree_out[l_pre] += 1  # calculate the out degree
                if self._modules[l_pre].bounded:
                    self._modules[l_pre].bounded = False
                    queue.append(self._modules[l_pre])
        node.bounded = True
        node.lA = C if bound_lower else None
        node.uA = C if bound_upper else None
        lb = ub = torch.tensor(0.).to(C.device)

        queue = [node]
        while len(queue) > 0:
            l = queue[0]  # backward from l
            queue = queue[1:]
            l.bounded = True

            if l.name in self.root_name or l == root: continue

            for l_pre in l.input_name:  # a predecessor is enqueued only after all of its successors are done
                _l = self._modules[l_pre]
                degree_out[l_pre] -= 1
                if degree_out[l_pre] == 0:
                    queue.append(_l)

            if l.lA is not None or l.uA is not None:

                def add_bound(node, lA, uA):
                    if lA is not None:
                        node.lA = lA if node.lA is None else (node.lA + lA)
                    if uA is not None:
                        node.uA = uA if node.uA is None else (node.uA + uA)

                input_nodes = [
                    self._modules[l_name] for l_name in l.input_name
                ]
                if _print_time:
                    start_time = time.time()
                logger.debug('Backward from {} to {}, {}'.format(
                    node.name, l.name, l))
                A, lower_b, upper_b = l.bound_backward(l.lA, l.uA,
                                                       *input_nodes)

                if _print_time:
                    time_elapsed = time.time() - start_time
                    if time_elapsed > 1e-3:
                        print(l, time_elapsed)
                lb = lb + lower_b
                ub = ub + upper_b

                for i, l_pre in enumerate(l.input_name):
                    _l = self._modules[l_pre]
                    add_bound(_l, lA=A[i][0], uA=A[i][1])

        batch_size = C.shape[0]
        output_shape = node.default_shape[1:]
        if np.prod(node.default_shape[1:]) != C.shape[1]:
            output_shape = [-1]

        if return_A:
            # return A matrix as a dict: {node.name: [A_lower, A_upper]}
            A_dict = {'bias': [lb, ub]}
            for i in range(len(root)):
                if root[i].lA is None and root[i].uA is None: continue
                A_dict.update({root[i].name: [root[i].lA, root[i].uA]})

        for i in range(len(root)):
            if root[i].lA is None and root[i].uA is None: continue
            if average_A and isinstance(root[i], BoundParams):
                A_shape = root[i].lA.shape if bound_lower else root[i].uA.shape
                lA = root[i].lA.mean(0, keepdim=True).repeat(
                    A_shape[0], *[1] *
                    len(A_shape[1:])) if bound_lower else None
                uA = root[i].uA.mean(0, keepdim=True).repeat(
                    A_shape[0], *[1] *
                    len(A_shape[1:])) if bound_upper else None
            else:
                lA = root[i].lA
                uA = root[i].uA
            if not isinstance(root[i].lA, eyeC):
                lA = root[i].lA.reshape(batch_size, root[i].lA.shape[1],
                                        -1) if bound_lower else None
            if not isinstance(root[i].uA, eyeC):
                uA = root[i].uA.reshape(batch_size, root[i].uA.shape[1],
                                        -1) if bound_upper else None
            if root[i].perturbation is not None:
                if isinstance(root[i], BoundParams):
                    # add batch_size dim for weights node
                    lb = lb + root[i].perturbation.concretize(
                        root[i].center.unsqueeze(0),
                        lA,
                        sign=-1,
                        aux=root[i].aux) if bound_lower else None
                    ub = ub + root[i].perturbation.concretize(
                        root[i].center.unsqueeze(0),
                        uA,
                        sign=+1,
                        aux=root[i].aux) if bound_upper else None
                else:
                    lb = lb + root[i].perturbation.concretize(
                        root[i].center, lA, sign=-1,
                        aux=root[i].aux) if bound_lower else None
                    ub = ub + root[i].perturbation.concretize(
                        root[i].center, uA, sign=+1,
                        aux=root[i].aux) if bound_upper else None
            elif i < self.num_global_inputs:
                if not isinstance(lA, eyeC):
                    lb = lb + lA.bmm(root[i].value.view(batch_size, -1, 1)
                                     ).squeeze(-1) if bound_lower else None
                else:
                    lb = lb + root[i].value.view(batch_size,
                                                 -1) if bound_lower else None
                if not isinstance(uA, eyeC):
                    ub = ub + uA.bmm(root[i].value.view(batch_size, -1, 1)
                                     ).squeeze(-1) if bound_upper else None
                else:
                    ub = ub + root[i].value.view(batch_size,
                                                 -1) if bound_upper else None
            else:
                if not isinstance(lA, eyeC):
                    lb = lb + lA.matmul(root[i].param.view(
                        -1, 1)).squeeze(-1) if bound_lower else None
                else:
                    lb = lb + root[i].param.view(1,
                                                 -1) if bound_lower else None
                if not isinstance(uA, eyeC):
                    ub = ub + uA.matmul(root[i].param.view(
                        -1, 1)).squeeze(-1) if bound_upper else None
                else:
                    ub = ub + root[i].param.view(1,
                                                 -1) if bound_upper else None

        node.lower = lb.view(batch_size, *
                             output_shape) if bound_lower else None
        node.upper = ub.view(batch_size, *
                             output_shape) if bound_upper else None

        if return_A: return node.lower, node.upper, A_dict
        return node.lower, node.upper
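
Example #8 also special-cases eyeC, a lightweight marker for an A matrix that is known to be the identity: since I·x = x, the bmm/matmul can be skipped entirely. Below is a hedged sketch of the idea using a hypothetical sentinel class; the real eyeC structure in the library may differ.

import torch

class Identity:
    """Hypothetical marker for A = I; carries no data (not the library's eyeC)."""

def apply_A(A, x):
    # x: (batch, n); A: (batch, m, n) tensor or the Identity marker
    if isinstance(A, Identity):
        return x                                   # I @ x == x, no matmul needed
    return A.bmm(x.unsqueeze(-1)).squeeze(-1)

x = torch.randn(2, 4)
print(torch.allclose(apply_A(Identity(), x), x))   # True
print(apply_A(torch.randn(2, 3, 4), x).shape)      # torch.Size([2, 3])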
Example #9
    def _convert_nodes(self, model, global_input):
        global_input_cpu = tuple([i.to('cpu') for i in list(global_input)])
        model.train()
        model.to('cpu')
        nodesOP, nodesIO = get_graph_params(model, global_input_cpu)
        model.to(self.device)
        for i in range(0, len(nodesIO)):
            if nodesIO[i].param is not None:
                nodesIO[i] = nodesIO[i]._replace(
                    param=nodesIO[i].param.to(self.device))

        # FIXME: better way to handle buffers, do not hard-code it for BN!
        # Other nodes can also have buffers.
        bn_nodes = []
        for n in range(len(nodesOP)):
            if nodesOP[n].op == 'onnx::BatchNormalization':
                # collect the names of running_mean and running_var
                bn_nodes.extend(nodesOP[n].inputs[3:])

        # Convert input nodes and parameters.
        for i in range(0, len(global_input)):
            nodesIO[i] = nodesIO[i]._replace(
                param=global_input[i],
                bound_node=BoundInput(nodesIO[i].inputs,
                                      nodesIO[i].name,
                                      nodesIO[i].ori_name,
                                      value=global_input[i],
                                      perturbation=nodesIO[i].perturbation))
        for i in range(len(global_input), len(nodesIO)):
            if nodesIO[i].name in bn_nodes:
                nodesIO[i] = nodesIO[i]._replace(bound_node=BoundBuffers(
                    nodesIO[i].inputs,
                    nodesIO[i].name,
                    nodesIO[i].ori_name,
                    value=nodesIO[i].param,
                    perturbation=nodesIO[i].perturbation))
            else:
                nodesIO[i] = nodesIO[i]._replace(bound_node=BoundParams(
                    nodesIO[i].inputs,
                    nodesIO[i].name,
                    nodesIO[i].ori_name,
                    value=nodesIO[i].param,
                    perturbation=nodesIO[i].perturbation))

        # Convert other operation nodes.
        for n in range(len(nodesOP)):
            attr = nodesOP[n].attr
            inputs, ori_names = self._get_node_input(nodesOP, nodesIO,
                                                     nodesOP[n])

            if nodesOP[n].op in bound_op_map:
                if nodesOP[n].op == 'onnx::BatchNormalization':
                    # the BatchNormalization node takes the model.training flag for its running mean/vars;
                    # pass training=False here so the bound wrapper does not wrongly update them
                    nodesOP[n] = nodesOP[n]._replace(
                        bound_node=bound_op_map[nodesOP[n].op]
                        (nodesOP[n].inputs, nodesOP[n].name, None, attr,
                         inputs, nodesOP[n].output_index, self.device, False))
                elif nodesOP[n].op in [
                        'onnx::Relu', 'onnx::LeakyRelu', 'onnx::Exp'
                ]:
                    nodesOP[n] = nodesOP[n]._replace(
                        bound_node=bound_op_map[nodesOP[n].op]
                        (nodesOP[n].inputs, nodesOP[n].name, None, attr,
                         inputs, nodesOP[n].output_index, self.device,
                         self.bound_opts))
                else:
                    nodesOP[n] = nodesOP[n]._replace(
                        bound_node=bound_op_map[nodesOP[n].op](
                            nodesOP[n].inputs, nodesOP[n].name, None, attr,
                            inputs, nodesOP[n].output_index, self.device))
            else:
                print(nodesOP[n])
                raise NotImplementedError('Unsupported operation {}'.format(
                    nodesOP[n].op))

            if self.verbose:
                logger.debug(
                    'Convert complete for {} with operation: {}'.format(
                        nodesOP[n].name, nodesOP[n].op))

        return nodesOP, nodesIO
Example #10
    def weights_backward_general(self,
                                 norm=np.inf,
                                 x=None,
                                 eps=None,
                                 C=None,
                                 ptb=None,
                                 node=None,
                                 root=None):
        assert (len(root) == 1)
        root = root[0]

        torch.cuda.empty_cache()

        logger.debug('Backward from {} {}'.format(node.name, node))

        degree_out = {}
        for l in self.nodes:
            l.bounded = True
            l.lA = l.uA = None
            degree_out[l.name] = 0
        queue = [node]
        while len(queue) > 0:
            l = queue[0]
            queue = queue[1:]
            for l_pre in l.input_name:
                degree_out[l_pre] += 1
                if self.node_dict[l_pre].bounded:
                    self.node_dict[l_pre].bounded = False
                    queue.append(self.node_dict[l_pre])
        node.bounded = True
        node.uA = C
        node.lA = C
        upper_sum_b = lower_sum_b = torch.tensor(0.).to(C.device)

        queue = [node]
        nodes_perturb_list = []
        while len(queue) > 0:
            l = queue[0]
            queue = queue[1:]
            l.bounded = True

            if l.name in self.root_name or l == root: continue

            for l_pre in l.input_name:
                _l = self.node_dict[l_pre]
                degree_out[l_pre] -= 1
                if degree_out[l_pre] == 0:
                    queue.append(_l)

            if l.uA is not None:

                def add_bound(node, uA, lA):
                    node.uA = uA if node.uA is None else (node.uA + uA)
                    node.lA = lA if node.lA is None else (node.lA + lA)

                logger.debug('Backward at {} {}'.format(l.name, l))

                if len(l.input_name) == 1:
                    input_node = self.node_dict[l.input_name[0]]
                    if hasattr(l, 'nonlinear') and l.nonlinear is True:
                        lA, lower_b, uA, upper_b = l.bound_backward(
                            l.lA, l.uA, input_node)
                        A = [(uA, lA)]
                    else:
                        [(lA_x, uA_x), (lA_y, uA_y)
                         ], upper_b, lower_b = l.two_bounds_backward(
                             l.lA, l.uA, input_node, l)
                        A = [(lA_x, uA_x)]  # y is weights, x is input
                        l.weight.lA_y, l.weight.uA_y = lA_y, uA_y
                        nodes_perturb_list.append(l.weight)
                else:
                    A, lower_b, upper_b = l.bound_backward(l.lA, l.uA)
                upper_sum_b = upper_sum_b + upper_b
                lower_sum_b = lower_sum_b + lower_b

                for i, l_pre in enumerate(l.input_name):
                    _l = self.node_dict[l_pre]
                    add_bound(_l, uA=A[i][0], lA=A[i][1])

        batch_size = C.shape[0]
        output_shape = node.forward_value.shape[1:]
        if node.forward_value.contiguous().view(batch_size,
                                                -1).shape[1] != C.shape[1]:
            output_shape = [-1]

        if node.from_input:
            lb = ptb.concretize_2bounds(x,
                                        root.lA,
                                        lower_sum_b,
                                        sign=-1,
                                        y=nodes_perturb_list)
            ub = ptb.concretize_2bounds(x,
                                        root.uA,
                                        upper_sum_b,
                                        sign=+1,
                                        y=nodes_perturb_list)
        else:
            lb, ub = lower_sum_b.reshape(-1), upper_sum_b.reshape(-1)

        return lb.view(batch_size,
                       *output_shape), ub.view(batch_size, *output_shape)
Example #11
    def _backward_general(self,
                          norm=np.inf,
                          x=None,
                          C=None,
                          ptb=None,
                          node=None,
                          root=None):
        logger.debug('Backward from {} {}'.format(node.name, node))

        degree_out = {}
        for l in self.nodes:
            l.bounded = True
            l.lA = l.uA = None
            degree_out[l.name] = 0
        queue = [node]
        while len(queue) > 0:
            l = queue[0]
            queue = queue[1:]
            for l_pre in l.input_name:
                degree_out[l_pre] += 1
                if self.node_dict[l_pre].bounded:
                    self.node_dict[l_pre].bounded = False
                    queue.append(self.node_dict[l_pre])
        node.bounded = True
        node.lA = node.uA = C
        lb = ub = torch.tensor(0.).to(C.device)

        queue = [node]
        while len(queue) > 0:
            l = queue[0]  # backward from l
            queue = queue[1:]
            l.bounded = True

            if l.name in self.root_name or l == root: continue

            for l_pre in l.input_name:
                _l = self.node_dict[l_pre]
                degree_out[l_pre] -= 1
                if degree_out[l_pre] == 0:
                    queue.append(_l)

            if l.uA is not None:

                def add_bound(node, lA, uA):
                    node.lA = lA if node.lA is None else (node.lA + lA)
                    node.uA = uA if node.uA is None else (node.uA + uA)

                logger.debug('Backward at {} {}'.format(l.name, l))

                input_nodes = [
                    self.node_dict[l_name] for l_name in l.input_name
                ]

                if len(l.input_name) == 1:
                    lA, lower_b, uA, upper_b = l.bound_backward(
                        l.lA, l.uA, *input_nodes)
                    A = [(lA, uA)]
                else:
                    A, lower_b, upper_b = l.bound_backward(
                        l.lA, l.uA, *input_nodes)
                ub = ub + upper_b
                lb = lb + lower_b

                for i, l_pre in enumerate(l.input_name):
                    _l = self.node_dict[l_pre]
                    add_bound(_l, lA=A[i][0], uA=A[i][1])

        batch_size = C.shape[0]
        output_shape = node.forward_value.shape[1:]
        if node.forward_value.contiguous().view(batch_size,
                                                -1).shape[1] != C.shape[1]:
            output_shape = [-1]

        for r in root:
            if r.lA is None: continue
            if isinstance(r.linear, LinearBound):
                uA = r.uA.reshape(batch_size, r.uA.shape[1], -1).matmul(
                    r.linear.uw.view(batch_size, r.linear.uw.shape[1],
                                     -1).transpose(1, 2))
                ub = ub + r.uA.reshape(batch_size, r.uA.shape[1], -1).matmul(
                    r.linear.ub.view(batch_size, -1, 1)).squeeze(-1)

                lA = r.lA.reshape(batch_size, r.lA.shape[1], -1).matmul(
                    r.linear.lw.view(batch_size, r.linear.lw.shape[1],
                                     -1).transpose(1, 2))
                lb = lb + r.lA.reshape(batch_size, r.lA.shape[1], -1).matmul(
                    r.linear.lb.view(batch_size, -1, 1)).squeeze(-1)

                lb = lb + ptb.concretize(x, lA, torch.zeros_like(lb), sign=-1)
                ub = ub + ptb.concretize(x, uA, torch.zeros_like(ub), sign=+1)
            else:
                lb = lb + r.lA.reshape(batch_size, r.lA.shape[1], -1).matmul(
                    r.forward_value.view(batch_size, -1, 1)).squeeze(-1)
                ub = ub + r.uA.reshape(batch_size, r.uA.shape[1], -1).matmul(
                    r.forward_value.view(batch_size, -1, 1)).squeeze(-1)

        node.lower = lb.view(batch_size, *output_shape)
        node.upper = ub.view(batch_size, *output_shape)
        return node.lower, node.upper
Example #12
    def __init__(self, norm, eps):
        self.norm = norm
        self.eps = eps  # eps of input x
        self.dual_norm = 1 if (norm == np.inf) else (np.float64(1.0) /
                                                     (1 - 1.0 / self.norm))
        logger.debug('Using l{} norm to concretize'.format(self.dual_norm))
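
The dual-norm expression in Example #12 follows from 1/p + 1/q = 1, i.e. q = p/(p-1), with q = 1 when p = inf (and q = inf when p = 1). A quick check of the formula used above; the helper name is illustrative.

import numpy as np

def dual_norm(p):
    return 1 if p == np.inf else np.float64(1.0) / (1 - 1.0 / p)

print(dual_norm(np.inf))   # 1.0  -> l_inf pairs with l_1
print(dual_norm(2))        # 2.0  -> l_2 is self-dual
print(dual_norm(1.5))      # 3.0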