def concretize_2bounds(self, x, Ax, sum_b, sign=-1, y=None):
    """Concretize a bound that is linear in the input ``x`` and in perturbed weights ``y``.

    Only linear layers are supported so far.

    Args:
        x: Center of the input perturbation; ``x.shape[0]`` is the batch size.
        Ax: Coefficients w.r.t. ``x``, applied with ``bmm`` to ``x`` flattened
            to ``(batch, -1, 1)``. If ``None``, ``None`` is returned.
        sum_b: Accumulated bias; ``sum_b.unsqueeze(-1)`` is added to the
            ``(batch, Ax.shape[1], 1)`` center.
        sign: ``-1`` for a lower bound, ``+1`` for an upper bound.
        y: Optional sequence of perturbed weight tensors. Each one must carry
            an ``eps`` radius and ``lA_y`` (used when ``sign == -1``) or
            ``uA_y`` (used when ``sign == +1``) coefficients. Defaults to no
            perturbed weights.

    Returns:
        Tensor of shape ``(batch, Ax.shape[1])`` with the concretized bound,
        or ``None`` when ``Ax`` is ``None``.

    Raises:
        ValueError: if ``sign`` is neither -1 nor +1.
    """
    # only support linear layer so far
    if Ax is None:
        return None
    if sign not in (-1, 1):
        raise ValueError(
            'sign must be -1 (lower bound) or +1 (upper bound), got {}'.format(
                sign))
    if y is None:
        y = []
    batch = x.shape[0]
    # Lower bounds use each weight's lA_y coefficients, upper bounds its uA_y.
    coef_name = 'lA_y' if sign == -1 else 'uA_y'
    _tmp_Ay = 0
    _tmp_center = 0
    for y_i in y:
        Ay = getattr(y_i, coef_name)
        # Log the coefficient actually used (the sign=+1 path may have no lA_y).
        logger.debug(y_i.shape)
        logger.debug(Ay.shape)
        Ay = Ay.reshape(*Ay.shape[:2], -1)
        # Radius term of this weight: sign * ||Ay||_dual * eps.
        _tmp_Ay = _tmp_Ay + sign * torch.norm(Ay, self.dual_norm, -1) * y_i.eps
        # Center term: Ay applied to the flattened weight, repeated per batch.
        y_center = y_i.reshape(-1).unsqueeze(-1).unsqueeze(0).repeat(
            batch, 1, 1)
        _tmp_center = _tmp_center + Ay.bmm(y_center)
    _tmp_center = _tmp_center + (
        Ax.bmm(x.reshape(batch, -1).unsqueeze(-1)) + sum_b.unsqueeze(-1))
    bound = _tmp_center.squeeze(-1) + sign * torch.norm(
        Ax, self.dual_norm, -1) * self.eps + _tmp_Ay
    return bound
def _convert(self, model, global_input):
    """Convert ``model`` into the bound-propagation node graph.

    The model is traced through ``self._convert_nodes``. The graph is then
    built and run forward repeatedly until ``self._split_complex`` reports
    that no complex node remains. Finally, each node's first parameter (as
    reported by ``named_parameters()``) is exposed in ``self._parameters``
    under the node's ``ori_name``, unless that key already exists.

    Args:
        model: The module to convert.
        global_input: A tensor or a tuple of tensors used for tracing.
    """
    if self.verbose:
        logger.info('Converting the model...')
    inputs = global_input if isinstance(global_input, tuple) else (global_input, )
    self.num_global_inputs = len(inputs)
    nodesOP, nodesIO = self._convert_nodes(model, inputs)
    inputs = tuple(tensor.to(self.device) for tensor in inputs)

    # Rebuild and re-run the graph until nothing is left to split.
    found_complex = True
    while found_complex:
        self._build_graph(nodesOP, nodesIO)
        self.forward(*inputs)
        nodesOP, nodesIO, found_complex = self._split_complex(nodesOP, nodesIO)

    for bound_node in self.nodes:
        named = list(bound_node.named_parameters())
        for _, param in named:
            # For parameter or input nodes, use their original name directly;
            # an already-registered name is never overwritten.
            self._parameters.setdefault(bound_node.ori_name, param)

    for title, group in (('NodesOP:', nodesOP), ('NodesIO', nodesIO)):
        logger.debug(title)
        for entry in group:
            logger.debug('{}'.format(entry._replace(param=None)))
    if self.verbose:
        logger.info('Model converted to support bounds')
def _convert(self, model, global_input):
    """Convert ``model`` into the bound-propagation node graph.

    The model is traced through ``self._convert_nodes``. The graph is then
    built and run forward repeatedly until ``self._split_complex`` reports
    that no complex node remains. Afterwards, ``self.ori_state_dict`` is
    loaded back into both the wrapper and ``model``, and the attribute is
    removed.

    Args:
        model: The module to convert.
        global_input: A tensor or a tuple of tensors used for tracing.
    """
    if self.verbose:
        logger.info('Converting the model...')
    inputs = global_input if isinstance(global_input, tuple) else (global_input, )
    self.num_global_inputs = len(inputs)
    nodesOP, nodesIO = self._convert_nodes(model, inputs)
    inputs = tuple(tensor.to(self.device) for tensor in inputs)

    found_complex = True
    while found_complex:
        self._build_graph(nodesOP, nodesIO)
        self.forward(*inputs)  # running means/vars changed
        nodesOP, nodesIO, found_complex = self._split_complex(nodesOP, nodesIO)

    self._get_node_name_map()

    # Reload self.ori_state_dict so running means/vars changed by forward()
    # above are reverted, on the wrapper and on the original model alike.
    saved_state = self.ori_state_dict
    self.load_state_dict(saved_state)
    model.load_state_dict(saved_state)
    del self.ori_state_dict

    for title, group in (('NodesOP:', nodesOP), ('NodesIO', nodesIO)):
        logger.debug(title)
        for entry in group:
            logger.debug('{}'.format(entry._replace(param=None)))
    if self.verbose:
        logger.info('Model converted to support bounds')
def _convert(self, model, global_input):
    """Convert ``model`` into the bound-propagation node graph.

    ``self.device`` is taken from the first input. The model is traced
    through ``self._convert_nodes``, and the graph is built and run forward
    repeatedly until ``self._split_complex`` reports that no complex node
    remains. Every node parameter is then registered on ``self`` under the
    name ``'<node.name>/<parameter name>'``.

    Args:
        model: The module to convert.
        global_input: A tensor or a tuple of tensors used for tracing.
    """
    if self.verbose:
        logger.info('Converting the model...')
    inputs = global_input if isinstance(global_input, tuple) else (global_input, )
    self.num_global_inputs = len(inputs)
    self.device = inputs[0].device
    nodesOP, nodesIO = self._convert_nodes(model, inputs)

    found_complex = True
    while found_complex:
        self._build_graph(nodesOP, nodesIO)
        self.forward(*inputs)
        nodesOP, nodesIO, found_complex = self._split_complex(nodesOP, nodesIO)

    for bound_node in self.nodes:
        # Materialize first: registering mutates self's parameter dict.
        for param_name, param in list(bound_node.named_parameters()):
            qualified = '{}/{}'.format(bound_node.name, param_name)
            self.register_parameter(qualified, param)

    for title, group in (('NodesOP:', nodesOP), ('NodesIO', nodesIO)):
        logger.debug(title)
        for entry in group:
            logger.debug('{}'.format(entry._replace(param=None)))
    if self.verbose:
        logger.info('Model converted to support bounds')
def _convert_nodes(self, model, global_input):
    """Trace ``model`` and attach a bound-node object to every graph node.

    The model is switched to training mode and traced on CPU copies of
    ``global_input``, then moved to ``self.device``. Every operator node gets
    an object built from ``bound_op_map``. The first ``len(global_input)``
    IO nodes become ``BoundInput`` nodes; the remaining ones become
    ``BoundParams`` nodes. Both kinds use ``method = 'forward'``.

    Returns:
        The ``(nodesOP, nodesIO)`` lists with ``bound_node`` filled in.

    Raises:
        NotImplementedError: for an operator missing from ``bound_op_map``.
    """
    cpu_inputs = tuple(tensor.to("cpu") for tensor in global_input)
    model.train()
    model.to('cpu')
    nodesOP, nodesIO = get_graph_params(model, cpu_inputs)
    model.to(self.device)

    # Move every recorded parameter tensor onto the target device.
    for idx, io_node in enumerate(nodesIO):
        if io_node.param is not None:
            nodesIO[idx] = io_node._replace(param=io_node.param.to(self.device))

    for idx in range(len(nodesOP)):
        attr = nodesOP[idx].attr
        inputs = self._get_node_input(nodesOP, nodesIO, nodesOP[idx])
        op_node = nodesOP[idx]
        if op_node.op not in bound_op_map:
            print(op_node)
            raise NotImplementedError('Unsupported operation {}'.format(
                op_node.op))
        ctor_args = [op_node.inputs, op_node.name, attr, inputs,
                     op_node.output_index, self.device]
        if op_node.op == 'onnx::BatchNormalization':
            # BatchNormalization node needs model.training flag to set running mean and vars
            ctor_args.append(model.training)
        nodesOP[idx] = op_node._replace(
            bound_node=bound_op_map[op_node.op](*ctor_args))
        if self.verbose:
            logger.debug(
                'Convert complete for {} with operation: {}'.format(
                    op_node.name, op_node.op))

    for idx, value in enumerate(global_input):
        nodesIO[idx] = nodesIO[idx]._replace(
            param=value,
            bound_node=BoundInput(nodesIO[idx].inputs, nodesIO[idx].name,
                                  value=value))
        nodesIO[idx].bound_node.method = 'forward'
    for idx in range(len(global_input), len(nodesIO)):
        io_node = nodesIO[idx]
        nodesIO[idx] = io_node._replace(bound_node=BoundParams(
            io_node.inputs, io_node.name, value=io_node.param))
        nodesIO[idx].bound_node.method = 'forward'
    return nodesOP, nodesIO
def parse_module(module, inputs, param_exclude=".*AuxLogits.*", param_include=None):
    """Trace ``module`` on ``inputs`` and parse the resulting ONNX-style graph.

    Args:
        module: The module to trace.
        inputs: A tensor or tuple of tensors passed to the tracer.
        param_exclude: Regex of parameter names to exclude, forwarded to
            ``_get_jit_params``.
        param_include: Optional regex of parameter names to include,
            forwarded to ``_get_jit_params``.

    Returns:
        ``(nodesOP, nodesIn, nodesOut, template)``. Each operator node's
        ``param`` is replaced by an ``OrderedDict`` that maps those of its
        input names matching an entry of ``nodesIn`` to that entry's
        ``param``. ``template`` is ``get_output_template`` of the traced
        output.
    """
    params = _get_jit_params(module, param_exclude=param_exclude,
                             param_include=param_include)
    torch_version = version.parse(torch.__version__)
    if torch_version < version.parse("1.4.0"):
        trace, out = torch.jit.get_trace_graph(module, inputs)
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
        trace_graph = trace.graph()
    else:
        # _get_trace_graph becomes an internal function in version >= 1.4.0
        trace, out = torch.jit._get_trace_graph(module, inputs)
        # this is not present in older torch
        from torch.onnx.symbolic_helper import _set_opset_version
        _set_opset_version(11 if torch_version < version.parse("1.5.0") else 12)
        trace_graph = torch.onnx._optimize_trace(
            trace, torch.onnx.OperatorExportTypes.ONNX)

    logger.debug('trace_graph: {}'.format(trace_graph))

    if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0:
        print("Graph before ONNX convertion:")
        print(trace)
        print("ONNX graph:")
        print(trace_graph)

    if not isinstance(inputs, tuple):
        inputs = (inputs, )
    nodesOP, nodesIn, nodesOut = parse_graph(trace_graph, tuple(inputs),
                                             tuple(params))

    # Index input/parameter nodes by name once, instead of rescanning nodesIn
    # for every input of every operator. If names repeat, the last node wins,
    # exactly as with a linear scan that keeps updating the entry.
    param_by_name = {}
    for n in nodesIn:
        param_by_name[n.name] = n.param
    for i, op_node in enumerate(nodesOP):
        param_in = OrderedDict(
            (inp, param_by_name[inp]) for inp in op_node.inputs
            if inp in param_by_name)
        nodesOP[i] = op_node._replace(param=param_in)

    template = get_output_template(out)
    return nodesOP, nodesIn, nodesOut, template
def _backward_general(self, C=None, node=None, root=None, bound_lower=True,
                      bound_upper=True):
    """Propagate the linear specification ``C`` backward from ``node`` to ``root``.

    Args:
        C: Specification matrix. ``C.shape[0]`` is used as the batch size and
            ``C.shape[1]`` as the number of specification rows.
        node: The node whose output is being bounded.
        root: List of root nodes at which the propagated coefficients are
            concretized.
        bound_lower: Whether to compute the lower bound.
        bound_upper: Whether to compute the upper bound.

    Returns:
        ``(lower, upper)``. Either is ``None`` when its flag is False. Both
        are also stored on ``node.lower`` / ``node.upper``.
    """
    logger.debug('Backward from {} {}'.format(node.name, node))
    # Pass 1: walk backward from `node`, mark each reachable node as not yet
    # bounded, and count how many reachable successors consume each node.
    degree_out = {}
    for l in self.nodes:
        l.bounded = True
        l.lA = l.uA = None
        degree_out[l.name] = 0
    queue = [node]
    while len(queue) > 0:
        l = queue[0]
        queue = queue[1:]
        for l_pre in l.input_name:
            degree_out[l_pre] += 1
            if self.node_dict[l_pre].bounded:
                self.node_dict[l_pre].bounded = False
                queue.append(self.node_dict[l_pre])

    node.bounded = True
    node.lA = C if bound_lower else None
    node.uA = C if bound_upper else None
    lb = ub = torch.tensor(0.).to(C.device)

    # Pass 2: a node is queued only after all of its reachable successors
    # have been processed, so its accumulated coefficients are complete when
    # they are pushed to its inputs.
    queue = [node]
    while len(queue) > 0:
        l = queue[0]  # backward from l
        queue = queue[1:]
        l.bounded = True
        # NOTE(review): `root` is a list here, so `l == root` compares a node
        # with a list; roots are presumably caught by `self.root_name`.
        if l.name in self.root_name or l == root:
            continue
        for l_pre in l.input_name:
            _l = self.node_dict[l_pre]
            degree_out[l_pre] -= 1
            if degree_out[l_pre] == 0:
                queue.append(_l)
        if l.lA is not None or l.uA is not None:
            def add_bound(node, lA, uA):
                # A node may feed several successors: sum their contributions.
                if lA is not None:
                    node.lA = lA if node.lA is None else (node.lA + lA)
                if uA is not None:
                    node.uA = uA if node.uA is None else (node.uA + uA)

            input_nodes = [
                self.node_dict[l_name] for l_name in l.input_name
            ]
            A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *input_nodes)
            lb = lb + lower_b
            ub = ub + upper_b
            for i, l_pre in enumerate(l.input_name):
                _l = self.node_dict[l_pre]
                add_bound(_l, lA=A[i][0], uA=A[i][1])

    batch_size = C.shape[0]
    output_shape = node.forward_value.shape[1:]
    if node.forward_value.contiguous().view(batch_size, -1).shape[1] != C.shape[1]:
        # C changes the number of outputs; keep the result flat.
        output_shape = [-1]

    for i in range(len(root)):
        if root[i].lA is None and root[i].uA is None:
            continue
        # lA is None when bound_lower=False. Log whichever coefficient exists
        # instead of dereferencing a possibly-None lA (that raised before).
        logged_A = root[i].lA if root[i].lA is not None else root[i].uA
        logger.debug('concretize node: {} shape: {}'.format(
            root[i], logged_A.shape))
        lA = root[i].lA.reshape(batch_size, root[i].lA.shape[1], -1) if bound_lower else None
        uA = root[i].uA.reshape(batch_size, root[i].uA.shape[1], -1) if bound_upper else None
        if root[i].perturbation is not None:
            if isinstance(root[i], BoundParams):
                # add batch_size dim for weights node
                lb = lb + root[i].perturbation.concretize(
                    root[i].center.unsqueeze(0).repeat(
                        ([batch_size] + [1] * len(root[i].center.shape))),
                    lA, sign=-1, aux=root[i].aux) if bound_lower else None
                ub = ub + root[i].perturbation.concretize(
                    root[i].center.unsqueeze(0).repeat(
                        ([batch_size] + [1] * len(root[i].center.shape))),
                    uA, sign=+1, aux=root[i].aux) if bound_upper else None
            else:
                lb = lb + root[i].perturbation.concretize(
                    root[i].center, lA, sign=-1,
                    aux=root[i].aux) if bound_lower else None
                ub = ub + root[i].perturbation.concretize(
                    root[i].center, uA, sign=+1,
                    aux=root[i].aux) if bound_upper else None
        elif i < self.num_global_inputs:
            # Unperturbed global input: per-example values, hence bmm.
            lb = lb + root[i].lA.reshape(
                batch_size, root[i].lA.shape[1], -1).bmm(
                    root[i].forward_value.view(batch_size, -1, 1)).squeeze(
                        -1) if bound_lower else None
            ub = ub + root[i].uA.reshape(
                batch_size, root[i].uA.shape[1], -1).bmm(
                    root[i].forward_value.view(batch_size, -1, 1)).squeeze(
                        -1) if bound_upper else None
        else:
            # Unperturbed parameter: one value shared across the batch.
            lb = lb + root[i].lA.reshape(
                batch_size, root[i].lA.shape[1],
                -1).matmul(root[i].forward_value.view(
                    -1, 1)).squeeze(-1) if bound_lower else None
            ub = ub + root[i].uA.reshape(
                batch_size, root[i].uA.shape[1],
                -1).matmul(root[i].forward_value.view(
                    -1, 1)).squeeze(-1) if bound_upper else None

    node.lower = lb.view(batch_size, *output_shape) if bound_lower else None
    node.upper = ub.view(batch_size, *output_shape) if bound_upper else None
    return node.lower, node.upper
def _backward_general(self, C=None, node=None, root=None, bound_lower=True,
                      bound_upper=True, return_A=False, average_A=False):
    """Propagate the linear specification ``C`` backward from ``node`` to ``root``.

    Args:
        C: Specification matrix. ``C.shape[0]`` is the batch size and
            ``C.shape[1]`` the number of specification rows.
        node: Node whose output is bounded.
        root: List of root nodes at which coefficients are concretized.
        bound_lower: Whether to compute the lower bound.
        bound_upper: Whether to compute the upper bound.
        return_A: Also return ``A_dict = {'bias': [lb, ub], <root name>:
            [lA, uA], ...}``, captured before concretization.
        average_A: For ``BoundParams`` roots, replace the coefficients by
            their mean over dim 0, repeated back to the original size, before
            concretizing.

    Returns:
        ``(lower, upper)``, or ``(lower, upper, A_dict)`` when ``return_A``.
        Either bound is ``None`` when its flag is False.
    """
    _print_time = False  # set True to print backward steps slower than 1 ms
    # Pass 1: mark reachable nodes as not yet bounded and count out-degrees.
    degree_out = {}
    for l in self._modules.values():
        l.bounded = True
        l.lA = l.uA = None
        degree_out[l.name] = 0
    queue = [node]
    while len(queue) > 0:
        l = queue[0]
        queue = queue[1:]
        for l_pre in l.input_name:
            degree_out[l_pre] += 1  # calculate the out degree
            if self._modules[l_pre].bounded:
                self._modules[l_pre].bounded = False
                queue.append(self._modules[l_pre])

    node.bounded = True
    node.lA = C if bound_lower else None
    node.uA = C if bound_upper else None
    lb = ub = torch.tensor(0.).to(C.device)

    # Pass 2: process each node once all of its reachable successors are done.
    queue = [node]
    while len(queue) > 0:
        l = queue[0]  # backward from l
        queue = queue[1:]
        l.bounded = True
        # NOTE(review): `root` is a list here, so `l == root` compares a node
        # with a list; roots are presumably caught by `self.root_name`.
        if l.name in self.root_name or l == root:
            continue
        for l_pre in l.input_name:
            # if all the succeeds are done, then we can turn to this node in the next iteration.
            _l = self._modules[l_pre]
            degree_out[l_pre] -= 1
            if degree_out[l_pre] == 0:
                queue.append(_l)
        if l.lA is not None or l.uA is not None:
            def add_bound(node, lA, uA):
                # A node may feed several successors: sum their contributions.
                if lA is not None:
                    node.lA = lA if node.lA is None else (node.lA + lA)
                if uA is not None:
                    node.uA = uA if node.uA is None else (node.uA + uA)

            input_nodes = [
                self._modules[l_name] for l_name in l.input_name
            ]
            if _print_time:
                start_time = time.time()
            logger.debug('Backward from {} to {}, {}'.format(
                node.name, l.name, l))
            A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *input_nodes)
            if _print_time:
                time_elapsed = time.time() - start_time
                if time_elapsed > 1e-3:
                    print(l, time_elapsed)
            lb = lb + lower_b
            ub = ub + upper_b
            for i, l_pre in enumerate(l.input_name):
                _l = self._modules[l_pre]
                add_bound(_l, lA=A[i][0], uA=A[i][1])

    batch_size = C.shape[0]
    output_shape = node.default_shape[1:]
    if np.prod(node.default_shape[1:]) != C.shape[1]:
        # C changes the number of outputs; keep the result flat.
        output_shape = [-1]

    if return_A:
        # return A matrix as a dict: {node.name: [A_lower, A_upper]}
        A_dict = {'bias': [lb, ub]}
        for i in range(len(root)):
            if root[i].lA is None and root[i].uA is None:
                continue
            A_dict.update({root[i].name: [root[i].lA, root[i].uA]})

    for i in range(len(root)):
        if root[i].lA is None and root[i].uA is None:
            continue
        if average_A and isinstance(root[i], BoundParams):
            A_shape = root[i].lA.shape if bound_lower else root[i].uA.shape
            lA = root[i].lA.mean(0, keepdim=True).repeat(
                A_shape[0], *[1] * len(A_shape[1:])) if bound_lower else None
            uA = root[i].uA.mean(0, keepdim=True).repeat(
                A_shape[0], *[1] * len(A_shape[1:])) if bound_upper else None
        else:
            lA = root[i].lA
            uA = root[i].uA
        # Flatten to (batch_size, A.shape[1], -1). Reshape the local lA/uA,
        # which may hold the averaged coefficients; reshaping root[i].lA/uA
        # here would silently discard the averaging done above.
        if not isinstance(lA, eyeC):
            lA = lA.reshape(batch_size, lA.shape[1], -1) if bound_lower else None
        if not isinstance(uA, eyeC):
            uA = uA.reshape(batch_size, uA.shape[1], -1) if bound_upper else None
        if root[i].perturbation is not None:
            if isinstance(root[i], BoundParams):
                # add batch_size dim for weights node
                lb = lb + root[i].perturbation.concretize(
                    root[i].center.unsqueeze(0), lA,
                    sign=-1, aux=root[i].aux) if bound_lower else None
                ub = ub + root[i].perturbation.concretize(
                    root[i].center.unsqueeze(0), uA,
                    sign=+1, aux=root[i].aux) if bound_upper else None
            else:
                lb = lb + root[i].perturbation.concretize(
                    root[i].center, lA, sign=-1,
                    aux=root[i].aux) if bound_lower else None
                ub = ub + root[i].perturbation.concretize(
                    root[i].center, uA, sign=+1,
                    aux=root[i].aux) if bound_upper else None
        elif i < self.num_global_inputs:
            # Unperturbed global input: per-example values, hence bmm.
            if not isinstance(lA, eyeC):
                lb = lb + lA.bmm(root[i].value.view(batch_size, -1, 1)
                                 ).squeeze(-1) if bound_lower else None
            else:
                lb = lb + root[i].value.view(batch_size, -1) if bound_lower else None
            if not isinstance(uA, eyeC):
                ub = ub + uA.bmm(root[i].value.view(batch_size, -1, 1)
                                 ).squeeze(-1) if bound_upper else None
            else:
                ub = ub + root[i].value.view(batch_size, -1) if bound_upper else None
        else:
            # Unperturbed parameter: one value shared across the batch.
            if not isinstance(lA, eyeC):
                lb = lb + lA.matmul(root[i].param.view(
                    -1, 1)).squeeze(-1) if bound_lower else None
            else:
                lb = lb + root[i].param.view(1, -1) if bound_lower else None
            if not isinstance(uA, eyeC):
                ub = ub + uA.matmul(root[i].param.view(
                    -1, 1)).squeeze(-1) if bound_upper else None
            else:
                ub = ub + root[i].param.view(1, -1) if bound_upper else None

    node.lower = lb.view(batch_size, *output_shape) if bound_lower else None
    node.upper = ub.view(batch_size, *output_shape) if bound_upper else None
    if return_A:
        return node.lower, node.upper, A_dict
    return node.lower, node.upper
def _convert_nodes(self, model, global_input):
    """Trace ``model`` and attach a bound-node object to every graph node.

    The model is switched to training mode and traced on CPU copies of
    ``global_input``, then moved to ``self.device``. IO nodes are wrapped
    first:

    - the first ``len(global_input)`` become ``BoundInput`` nodes;
    - BatchNormalization inputs 3 and up become ``BoundBuffers`` nodes;
    - all other IO nodes become ``BoundParams`` nodes.

    Operator nodes are then built from ``bound_op_map``.

    Returns:
        The ``(nodesOP, nodesIO)`` lists with ``bound_node`` filled in.

    Raises:
        NotImplementedError: for an operator missing from ``bound_op_map``.
    """
    cpu_inputs = tuple(tensor.to('cpu') for tensor in global_input)
    model.train()
    model.to('cpu')
    nodesOP, nodesIO = get_graph_params(model, cpu_inputs)
    model.to(self.device)
    for idx, io_node in enumerate(nodesIO):
        if io_node.param is not None:
            nodesIO[idx] = io_node._replace(param=io_node.param.to(self.device))

    # FIXME: better way to handle buffers, do not hard-code it for BN!
    # Other nodes can also have buffers.
    # Names of BN inputs from index 3 on (running_mean and running_var).
    bn_buffer_names = [
        name for op_node in nodesOP
        if op_node.op == 'onnx::BatchNormalization'
        for name in op_node.inputs[3:]
    ]

    # Convert input nodes and parameters.
    for idx, value in enumerate(global_input):
        io_node = nodesIO[idx]
        nodesIO[idx] = io_node._replace(
            param=value,
            bound_node=BoundInput(io_node.inputs, io_node.name,
                                  io_node.ori_name, value=value,
                                  perturbation=io_node.perturbation))
    for idx in range(len(global_input), len(nodesIO)):
        io_node = nodesIO[idx]
        wrapper_cls = BoundBuffers if io_node.name in bn_buffer_names else BoundParams
        nodesIO[idx] = io_node._replace(bound_node=wrapper_cls(
            io_node.inputs, io_node.name, io_node.ori_name,
            value=io_node.param, perturbation=io_node.perturbation))

    # Convert other operation nodes.
    for idx in range(len(nodesOP)):
        attr = nodesOP[idx].attr
        inputs, _ = self._get_node_input(nodesOP, nodesIO, nodesOP[idx])
        op_node = nodesOP[idx]
        if op_node.op not in bound_op_map:
            print(op_node)
            raise NotImplementedError('Unsupported operation {}'.format(
                op_node.op))
        ctor_args = [op_node.inputs, op_node.name, None, attr, inputs,
                     op_node.output_index, self.device]
        if op_node.op == 'onnx::BatchNormalization':
            # The training flag is fixed to False so the bound wrapper does
            # not wrongly update running mean/vars.
            ctor_args.append(False)
        elif op_node.op in ['onnx::Relu', 'onnx::LeakyRelu', 'onnx::Exp']:
            ctor_args.append(self.bound_opts)
        nodesOP[idx] = op_node._replace(
            bound_node=bound_op_map[op_node.op](*ctor_args))
        if self.verbose:
            logger.debug(
                'Convert complete for {} with operation: {}'.format(
                    op_node.name, op_node.op))
    return nodesOP, nodesIO
def weights_backward_general(self, norm=np.inf, x=None, eps=None, C=None,
                             ptb=None, node=None, root=None):
    """Backward bound propagation where both the input and layer weights are perturbed.

    Coefficients are propagated backward from ``node`` with specification
    ``C``. Single-input layers are handled in one of two ways:

    - layers with ``nonlinear is True`` go through ``bound_backward``;
    - all others go through ``two_bounds_backward``, which also yields
      coefficients for ``l.weight``.

    The weights handled the second way are collected and passed as ``y`` to
    ``ptb.concretize_2bounds``.

    Args:
        norm: Unused in this function body.
        x: Input passed to ``ptb.concretize_2bounds``.
        eps: Unused in this function body.
        C: Specification matrix; ``C.shape[0]`` is the batch size.
        ptb: Perturbation object providing ``concretize_2bounds``.
        node: Node whose output is bounded.
        root: List containing exactly one root node.

    Returns:
        ``(lb, ub)``, each viewed as ``(batch_size, *output_shape)``.
    """
    assert (len(root) == 1)
    root = root[0]
    torch.cuda.empty_cache()
    logger.debug('Backward from {} {}'.format(node.name, node))
    # Pass 1: mark reachable nodes as not yet bounded and count out-degrees.
    degree_out = {}
    for l in self.nodes:
        l.bounded = True
        l.lA = l.uA = None
        degree_out[l.name] = 0
    queue = [node]
    while len(queue) > 0:
        l = queue[0]
        queue = queue[1:]
        for l_pre in l.input_name:
            degree_out[l_pre] += 1
            if self.node_dict[l_pre].bounded:
                self.node_dict[l_pre].bounded = False
                queue.append(self.node_dict[l_pre])

    node.bounded = True
    node.uA = C
    node.lA = C
    upper_sum_b = lower_sum_b = torch.tensor(0.).to(C.device)
    queue = [node]
    nodes_perturb_list = []
    # Pass 2: process each node once all of its reachable successors are done.
    while len(queue) > 0:
        l = queue[0]
        queue = queue[1:]
        l.bounded = True
        # self.root_name holds node *names*. Comparing the node object itself
        # (`l in self.root_name`) never matched, so roots were not skipped.
        if l.name in self.root_name or l == root:
            continue
        for l_pre in l.input_name:
            _l = self.node_dict[l_pre]
            degree_out[l_pre] -= 1
            if degree_out[l_pre] == 0:
                queue.append(_l)
        if l.uA is not None:
            def add_bound(node, uA, lA):
                # A node may feed several successors: sum their contributions.
                node.uA = uA if node.uA is None else (node.uA + uA)
                node.lA = lA if node.lA is None else (node.lA + lA)

            logger.debug('Backward at {} {}'.format(l.name, l))
            if len(l.input_name) == 1:
                input_node = self.node_dict[l.input_name[0]]
                if hasattr(l, 'nonlinear') and l.nonlinear is True:
                    lA, lower_b, uA, upper_b = l.bound_backward(
                        l.lA, l.uA, input_node)
                    A = [(uA, lA)]
                else:
                    # NOTE(review): A entries are consumed below as (uA, lA),
                    # but this branch builds (lA_x, uA_x), and the biases are
                    # unpacked as (upper_b, lower_b). Confirm the return order
                    # of two_bounds_backward; as written, lA_x is added to uA.
                    [(lA_x, uA_x), (lA_y, uA_y)
                     ], upper_b, lower_b = l.two_bounds_backward(
                         l.lA, l.uA, input_node, l)
                    A = [(lA_x, uA_x)]
                    # y is weights, x is input
                    l.weight.lA_y, l.weight.uA_y = lA_y, uA_y
                    nodes_perturb_list.append(l.weight)
            else:
                # NOTE(review): other backward variants read A[i] as (lA, uA);
                # here A[i][0] is treated as uA. Confirm bound_backward's order.
                A, lower_b, upper_b = l.bound_backward(l.lA, l.uA)
            upper_sum_b = upper_sum_b + upper_b
            lower_sum_b = lower_sum_b + lower_b
            for i, l_pre in enumerate(l.input_name):
                _l = self.node_dict[l_pre]
                add_bound(_l, uA=A[i][0], lA=A[i][1])

    batch_size = C.shape[0]
    output_shape = node.forward_value.shape[1:]
    if node.forward_value.contiguous().view(batch_size, -1).shape[1] != C.shape[1]:
        # C changes the number of outputs; keep the result flat.
        output_shape = [-1]
    if node.from_input:
        lb = ptb.concretize_2bounds(x, root.lA, lower_sum_b, sign=-1,
                                    y=nodes_perturb_list)
        ub = ptb.concretize_2bounds(x, root.uA, upper_sum_b, sign=+1,
                                    y=nodes_perturb_list)
    else:
        lb, ub = lower_sum_b.reshape(-1), upper_sum_b.reshape(-1)
    return lb.view(batch_size, *output_shape), ub.view(batch_size, *output_shape)
def _backward_general(self, norm=np.inf, x=None, C=None, ptb=None, node=None, root=None):
    """Propagate the linear specification ``C`` backward from ``node`` to ``root``.

    For root nodes whose ``linear`` attribute is a ``LinearBound``, the
    coefficients are composed with that linear relaxation and concretized
    with ``ptb.concretize`` on ``x``. Other roots are concretized with their
    ``forward_value``.

    Args:
        norm: Unused in this function body.
        x: Input passed to ``ptb.concretize``.
        C: Specification matrix. ``C.shape[0]`` is the batch size and
            ``C.shape[1]`` the number of specification rows.
        ptb: Perturbation object providing ``concretize``.
        node: Node whose output is bounded.
        root: Iterable of root nodes.

    Returns:
        ``(node.lower, node.upper)``; both are also stored on ``node``.
    """
    logger.debug('Backward from {} {}'.format(node.name, node))
    # Pass 1: mark reachable nodes as not yet bounded and count, per node,
    # how many reachable successors consume it.
    degree_out = {}
    for l in self.nodes:
        l.bounded = True
        l.lA = l.uA = None
        degree_out[l.name] = 0
    queue = [node]
    while len(queue) > 0:
        l = queue[0]
        queue = queue[1:]
        for l_pre in l.input_name:
            degree_out[l_pre] += 1
            if self.node_dict[l_pre].bounded:
                self.node_dict[l_pre].bounded = False
                queue.append(self.node_dict[l_pre])
    node.bounded = True
    node.lA = node.uA = C
    lb = ub = torch.tensor(0.).to(C.device)
    # Pass 2: a node is queued only after all of its reachable successors
    # were processed, so its accumulated coefficients are complete.
    queue = [node]
    while len(queue) > 0:
        l = queue[0]  # backward from l
        queue = queue[1:]
        l.bounded = True
        # NOTE(review): `root` is iterated as a collection below, so
        # `l == root` compares a node with that collection; roots are
        # presumably caught by `self.root_name` — confirm.
        if l.name in self.root_name or l == root: continue
        for l_pre in l.input_name:
            _l = self.node_dict[l_pre]
            degree_out[l_pre] -= 1
            if degree_out[l_pre] == 0:
                queue.append(_l)
        # Only uA is tested; lA and uA are both initialized to C above.
        if l.uA is not None:
            def add_bound(node, lA, uA):
                # A node may feed several successors: sum their contributions.
                node.lA = lA if node.lA is None else (node.lA + lA)
                node.uA = uA if node.uA is None else (node.uA + uA)
            logger.debug('Backward at {} {}'.format(l.name, l))
            input_nodes = [
                self.node_dict[l_name] for l_name in l.input_name
            ]
            # Single-input layers return (lA, lower_b, uA, upper_b); others
            # return the per-input list A of (lA, uA) plus the two biases.
            if len(l.input_name) == 1:
                lA, lower_b, uA, upper_b = l.bound_backward(
                    l.lA, l.uA, *input_nodes)
                A = [(lA, uA)]
            else:
                A, lower_b, upper_b = l.bound_backward(
                    l.lA, l.uA, *input_nodes)
            ub = ub + upper_b
            lb = lb + lower_b
            for i, l_pre in enumerate(l.input_name):
                _l = self.node_dict[l_pre]
                add_bound(_l, lA=A[i][0], uA=A[i][1])
    batch_size = C.shape[0]
    output_shape = node.forward_value.shape[1:]
    if node.forward_value.contiguous().view(batch_size, -1).shape[1] != C.shape[1]:
        # C changes the number of outputs; keep the result flat.
        output_shape = [-1]
    for r in root:
        if r.lA is None: continue
        if isinstance(r.linear, LinearBound):
            # Compose the coefficients with the root's linear relaxation
            # (uw/ub for the upper side, lw/lb for the lower side), then
            # concretize the composed coefficients over x.
            # NOTE(review): uA is multiplied only by uw (and lA only by lw)
            # regardless of coefficient sign; this is sound only if the
            # coefficients are non-negative or lw == uw — confirm.
            uA = r.uA.reshape(batch_size, r.uA.shape[1], -1).matmul(
                r.linear.uw.view(batch_size, r.linear.uw.shape[1], -1).transpose(1, 2))
            ub = ub + r.uA.reshape(batch_size, r.uA.shape[1], -1).matmul(
                r.linear.ub.view(batch_size, -1, 1)).squeeze(-1)
            lA = r.lA.reshape(batch_size, r.lA.shape[1], -1).matmul(
                r.linear.lw.view(batch_size, r.linear.lw.shape[1], -1).transpose(1, 2))
            lb = lb + r.lA.reshape(batch_size, r.lA.shape[1], -1).matmul(
                r.linear.lb.view(batch_size, -1, 1)).squeeze(-1)
            lb = lb + ptb.concretize(x, lA, torch.zeros_like(lb), sign=-1)
            ub = ub + ptb.concretize(x, uA, torch.zeros_like(ub), sign=+1)
        else:
            # No relaxation: apply the coefficients to the root's value.
            lb = lb + r.lA.reshape(batch_size, r.lA.shape[1], -1).matmul(
                r.forward_value.view(batch_size, -1, 1)).squeeze(-1)
            ub = ub + r.uA.reshape(batch_size, r.uA.shape[1], -1).matmul(
                r.forward_value.view(batch_size, -1, 1)).squeeze(-1)
    node.lower = lb.view(batch_size, *output_shape)
    node.upper = ub.view(batch_size, *output_shape)
    return node.lower, node.upper
def __init__(self, norm, eps):
    """Store the perturbation norm and radius, and derive the dual norm.

    Args:
        norm: Order p of the l_p norm of the input perturbation; ``np.inf``
            is supported.
        eps: eps of input x.

    The dual norm q satisfies 1/p + 1/q = 1; it is 1 when p is infinite.
    """
    self.norm = norm
    self.eps = eps  # eps of input x
    if norm == np.inf:
        self.dual_norm = 1
    else:
        # np.float64 makes p == 1 give inf (with a numpy warning) instead of
        # raising ZeroDivisionError.
        self.dual_norm = np.float64(1.0) / (1 - 1.0 / self.norm)
    logger.debug('Using l{} norm to concretize'.format(self.dual_norm))