Example #1
    def convert(linear_layer, bound_opts=None):
        # Wrap an nn.Linear (or SubnetLinear) layer in a BoundLinear module.
        if 'SubnetLinear' in linear_layer.__class__.__name__:
            l = BoundLinear(linear_layer.in_features,
                linear_layer.out_features,
                linear_layer.bias is not None,
                bound_opts)
            l.layer = linear_layer

            # Select the top-k weights by popup score and cache the masked weight.
            adj = GetSubnet.apply(linear_layer.popup_scores.abs(),
                        linear_layer.k)
            l.layer.w = l.layer.weight * adj
            l.weight = linear_layer.weight
            l.bias = linear_layer.bias
            return l

        # Plain nn.Linear: share the weight and bias tensors directly.
        l = BoundLinear(linear_layer.in_features,
                linear_layer.out_features,
                linear_layer.bias is not None,
                bound_opts)
        l.weight = linear_layer.weight
        l.bias = linear_layer.bias

        return l
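GetSubnet is used in every example here but never defined in the listing. The sketch below is an assumption about this codebase, following the standard straight-through top-k mask from the scored-subnetwork (hidden-networks / HYDRA) line of work: the forward pass keeps the fraction k of entries with the largest scores as 1 and zeroes the rest, while the backward pass sends the gradient straight through so the popup scores stay trainable.

import torch

class GetSubnet(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Build a binary mask: 1 for the top-k fraction of scores, 0 elsewhere.
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())
        # flat_out is a view of out, so the in-place writes fill the mask.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1
        return out

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass the gradient to the scores unchanged.
        return g, None

With this, adj = GetSubnet.apply(layer.popup_scores.abs(), layer.k) is a binary mask and layer.weight * adj is the pruned weight actually used in the forward pass.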
Example #2
    def forward(self, input):
        # If the wrapped layer is a SubnetLinear, mask its weights with the
        # top-k subnetwork before applying the affine map.
        if 'SubnetLinear' in self.layer.__class__.__name__:
            adj = GetSubnet.apply(self.layer.popup_scores.abs(), self.layer.k)
            self.layer.w = self.layer.weight * adj
            output = F.linear(input, self.layer.w, self.layer.bias)
        else:
            output = super(BoundLinear, self).forward(input)

        return output
Example #3
 def interval_propagate(self, norm, h_U, h_L, eps):
     # Interval bound propagation through the convolution: the output bounds
     # are center +/- deviation.
     if 'SubnetConv' in self.layer.__class__.__name__:
         # Mask the weights with the top-k subnetwork first.
         adj = GetSubnet.apply(self.layer.popup_scores.abs(), self.layer.k)
         self.layer.w = self.layer.weight * adj
         if norm == np.inf:
             mid = (h_U + h_L) / 2.0
             diff = (h_U - h_L) / 2.0
             weight_abs = self.layer.w.abs()
             deviation = F.conv2d(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups)
         else:
             # L2 norm
             mid = h_U
             # TODO: consider padding here?
             deviation = torch.mul(self.layer.w, self.layer.w).sum((1,2,3)).sqrt() * eps
             deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
         center = F.conv2d(mid, self.layer.w, self.bias, self.stride, self.padding, self.dilation, self.groups)
         upper = center + deviation
         lower = center - deviation
         return np.inf, upper, lower, 0, 0, 0, 0
     else:
         if norm == np.inf:
             mid = (h_U + h_L) / 2.0
             diff = (h_U - h_L) / 2.0
             weight_abs = self.weight.abs()
             deviation = F.conv2d(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups)
         else:
             # L2 norm
             mid = h_U
             logger.debug('mid %s', mid.size())
             # TODO: consider padding here?
             deviation = torch.mul(self.weight, self.weight).sum((1,2,3)).sqrt() * eps
             logger.debug('weight %s', self.weight.size())
             logger.debug('deviation %s', deviation.size())
             deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
             logger.debug('unsqueezed deviation %s', deviation.size())
         center = F.conv2d(mid, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         logger.debug('center %s', center.size())
         upper = center + deviation
         lower = center - deviation
         return np.inf, upper, lower, 0, 0, 0, 0
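The Linf branch above is plain interval bound propagation: for inputs in [h_L, h_U], the center is the convolution of the midpoint and the deviation is the convolution of the input radius with the absolute weights. A small self-contained check of that formula; the tensor names (weight, bias, h_L, h_U) are illustrative and not from the listing:

import torch
import torch.nn.functional as F

# Illustrative check of the Linf interval-propagation formula for a convolution.
weight = torch.randn(2, 3, 3, 3)
bias = torch.randn(2)
h_L = torch.zeros(1, 3, 8, 8)
h_U = h_L + 0.1

mid = (h_U + h_L) / 2.0
diff = (h_U - h_L) / 2.0
center = F.conv2d(mid, weight, bias, stride=1, padding=1)
deviation = F.conv2d(diff, weight.abs(), None, stride=1, padding=1)
lower, upper = center - deviation, center + deviation

# Any input inside [h_L, h_U] maps inside [lower, upper].
x = h_L + (h_U - h_L) * torch.rand_like(h_L)
y = F.conv2d(x, weight, bias, stride=1, padding=1)
assert torch.all(y >= lower - 1e-6) and torch.all(y <= upper + 1e-6)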
Example #4
    def convert(l, bound_opts=None):
        # Wrap an nn.Conv2d (or SubnetConv) layer in a BoundConv2d module.
        if 'SubnetConv' in l.__class__.__name__:
            nl = BoundConv2d(l.in_channels,
                            l.out_channels,
                            l.kernel_size,
                            l.stride,
                            l.padding,
                            l.dilation,
                            l.groups,
                            l.bias is not None,
                            bound_opts
                        )

            nl.layer = l
            # Select the top-k weights by popup score and cache the masked weight.
            adj = GetSubnet.apply(l.popup_scores.abs(), l.k)
            l.w = l.weight * adj
            nl.weight = l.weight
            nl.bias = l.bias
            return nl

        # Plain nn.Conv2d: share the weight and bias tensors directly.
        nl = BoundConv2d(l.in_channels,
                        l.out_channels,
                        l.kernel_size,
                        l.stride,
                        l.padding,
                        l.dilation,
                        l.groups,
                        l.bias is not None,
                        bound_opts
                    )
        nl.weight = l.weight
        nl.bias = l.bias
        logger.debug(nl.bias.size())
        logger.debug(nl.weight.size())
        return nl
Example #5
 def forward(self, input):
     if 'SubnetConv' in self.layer.__class__.__name__:
         # Mask the convolution weights with the top-k subnetwork before the forward pass.
         adj = GetSubnet.apply(self.layer.popup_scores.abs(), self.layer.k)
         self.layer.w = self.layer.weight * adj
         output = F.conv2d(input,
                         self.layer.w,
                         self.bias,
                         self.stride,
                         self.padding,
                         self.dilation,
                         self.groups
                     )
     else:
         output = super(BoundConv2d, self).forward(input)
     # Record the per-example input/output shapes for later bound computations.
     self.output_shape = output.size()[1:]
     self.input_shape = input.size()[1:]
     return output
Example #6
    def interval_propagate(self, norm, h_U, h_L, eps, C = None):
        if 'SubnetLinear' in self.layer.__class__.__name__:
            adj = GetSubnet.apply(self.layer.popup_scores.abs(), self.layer.k)

            # Use only the subnetwork in the forward pass.
            self.layer.w = self.layer.weight * adj

            # merge the specification
            if C is not None:
                # after multiplication with C, the weight has shape (batch, output_shape, prev_layer_shape);
                # there is a batch dimension because each example has a different C
                weight = C.matmul(self.layer.w)
                bias = C.matmul(self.bias)
            else:
                # weight dimension (this_layer_shape, prev_layer_shape)
                weight = self.layer.w
                bias = self.bias

            if norm == np.inf:
                # Linf norm
                mid = (h_U + h_L) / 2.0
                diff = (h_U - h_L) / 2.0
                weight_abs = weight.abs()
                if C is not None:
                    center = weight.matmul(mid.unsqueeze(-1)) + bias.unsqueeze(-1)
                    deviation = weight_abs.matmul(diff.unsqueeze(-1))
                    # these have an extra (1,) dimension as the last dimension
                    center = center.squeeze(-1)
                    deviation = deviation.squeeze(-1)
                else:
                    # fused multiply-add
                    center = torch.addmm(bias, mid, weight.t())
                    deviation = diff.matmul(weight_abs.t())
            else:
                # Lp norm (p = norm): the deviation uses the dual norm of each weight row
                h = h_U # h_U = h_L, and eps is used
                dual_norm = np.float64(1.0) / (1 - 1.0 / norm)
                if C is not None:
                    center = weight.matmul(h.unsqueeze(-1)) + bias.unsqueeze(-1)
                    center = center.squeeze(-1)
                else:
                    center = torch.addmm(bias, h, weight.t())
                deviation = weight.norm(dual_norm, -1) * eps

            upper = center + deviation
            lower = center - deviation
            # output 
            return np.inf, upper, lower, 0, 0, 0, 0
        else:
            # merge the specification
            if C is not None:
                # after multiplication with C, the weight has shape (batch, output_shape, prev_layer_shape);
                # there is a batch dimension because each example has a different C
                weight = C.matmul(self.weight)
                bias = C.matmul(self.bias)
            else:
                # weight dimension (this_layer_shape, prev_layer_shape)
                weight = self.weight
                bias = self.bias

            if norm == np.inf:
                # Linf norm
                mid = (h_U + h_L) / 2.0
                diff = (h_U - h_L) / 2.0
                weight_abs = weight.abs()
                if C is not None:
                    center = weight.matmul(mid.unsqueeze(-1)) + bias.unsqueeze(-1)
                    deviation = weight_abs.matmul(diff.unsqueeze(-1))
                    # these have an extra (1,) dimension as the last dimension
                    center = center.squeeze(-1)
                    deviation = deviation.squeeze(-1)
                else:
                    # fused multiply-add
                    center = torch.addmm(bias, mid, weight.t())
                    deviation = diff.matmul(weight_abs.t())
            else:
                # Lp norm (p = norm): the deviation uses the dual norm of each weight row
                h = h_U # h_U = h_L, and eps is used
                dual_norm = np.float64(1.0) / (1 - 1.0 / norm)
                if C is not None:
                    center = weight.matmul(h.unsqueeze(-1)) + bias.unsqueeze(-1)
                    center = center.squeeze(-1)
                else:
                    center = torch.addmm(bias, h, weight.t())
                deviation = weight.norm(dual_norm, -1) * eps

            upper = center + deviation
            lower = center - deviation
            # output 
            return np.inf, upper, lower, 0, 0, 0, 0
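The Linf branch above is the same center/deviation computation for an affine map: center = W @ mid + b and deviation = |W| @ diff. A small self-contained check of that formula using the fused torch.addmm form from the listing; the tensor names (W, b, h_L, h_U) are illustrative and not from the listing:

import torch

# Illustrative check of the Linf interval-propagation formula for a linear layer.
W = torch.randn(3, 4)
b = torch.randn(3)
h_L = torch.zeros(2, 4)          # lower input bounds (batch of 2)
h_U = h_L + 0.1                  # upper input bounds

mid = (h_U + h_L) / 2.0
diff = (h_U - h_L) / 2.0
center = torch.addmm(b, mid, W.t())        # W @ mid + b, fused multiply-add
deviation = diff.matmul(W.abs().t())       # |W| @ diff
lower, upper = center - deviation, center + deviation

# Any input inside [h_L, h_U] maps inside [lower, upper].
x = h_L + (h_U - h_L) * torch.rand_like(h_L)
y = torch.addmm(b, x, W.t())
assert torch.all(y >= lower - 1e-6) and torch.all(y <= upper + 1e-6)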
Example #7
	def forward(self, ix):
		# Mask the convolution weights with the top-k subnetwork, then propagate
		# the incoming interval/symbolic-interval object through the layer.
		adj = GetSubnet.apply(self.layer.popup_scores.abs(), self.layer.k)
		self.layer.w = self.layer.weight * adj

		if isinstance(ix, mix_interval):
			ix.shrink()
			ix.c = F.conv2d(ix.c, self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding, 
						   bias=self.layer.bias)
			ix.idep = F.conv2d(ix.idep, self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding)

			for i in range(len(ix.edep)):
				ix.edep[i] = F.conv2d(ix.edep[i], self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding)
			ix.shape = list(ix.c.shape[1:])
			ix.n = list(ix.c[0].reshape(-1).size())[0]

			c, e = ix.nc, ix.ne

			c = F.conv2d(c, self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding, 
						   bias=self.layer.bias)
			e = F.conv2d(e, self.layer.w.abs(), 
						   stride=self.layer.stride,
						   padding=self.layer.padding)

			ix.nc, ix.ne, ix.nl, ix.nu = c, e, c-e, c+e

			ix.concretize()

			return ix

		if isinstance(ix, Symbolic_interval):

			ix.shrink()
			ix.c = F.conv2d(ix.c, self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding, 
						   bias=self.layer.bias)
			ix.idep = F.conv2d(ix.idep, self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding)

			for i in range(len(ix.edep)):
				ix.edep[i] = F.conv2d(ix.edep[i], self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding)
			ix.shape = list(ix.c.shape[1:])
			ix.n = list(ix.c[0].reshape(-1).size())[0]
			ix.concretize()
			return ix


		if isinstance(ix, Interval):
			# Plain interval: propagate the center c and the (nonnegative) error e.
			c = ix.c
			e = ix.e
			c = F.conv2d(c, self.layer.w, 
						   stride=self.layer.stride,
						   padding=self.layer.padding, 
						   bias=self.layer.bias)
			e = F.conv2d(e, self.layer.w.abs(), 
						   stride=self.layer.stride,
						   padding=self.layer.padding)
			
			ix.update_lu(c-e, c+e)
			
			return ix
Example #8
	def forward(self, ix):
		# Mask the linear weights with the top-k subnetwork, then propagate the
		# incoming interval/symbolic-interval object through the layer.
		adj = GetSubnet.apply(self.layer.popup_scores.abs(), self.layer.k)
		self.layer.w = self.layer.weight * adj

		if isinstance(ix, mix_interval):
			ix.c = F.linear(ix.c, self.layer.w, bias=self.layer.bias)
			ix.idep = F.linear(ix.idep, self.layer.w)
			for i in range(len(ix.edep)):
				ix.edep[i] = F.linear(ix.edep[i], self.layer.w)
			ix.shape = list(ix.c.shape[1:])
			ix.n = list(ix.c[0].view(-1).size())[0]

			c, e = ix.nc, ix.ne
			if self.wc_matrix is None:
				c = F.linear(c, self.layer.w, bias=self.layer.bias)
				e = F.linear(e, self.layer.w.abs())
				ix.nc, ix.ne, ix.nl, ix.nu = c, e, c-e, c+e
			else:
				weight = self.wc_matrix.matmul(self.layer.w)
				bias = self.wc_matrix.matmul(self.layer.bias)
				
				c = weight.matmul(c.unsqueeze(-1)) + bias.unsqueeze(-1)
				e = weight.abs().matmul(e.unsqueeze(-1))

				c, e = c.squeeze(-1), e.squeeze(-1)

				ix.nc, ix.ne, ix.nl, ix.nu = -c, -e, -c-e, -c+e

			ix.concretize()

			return ix

		if isinstance(ix, Symbolic_interval):
			ix.c = F.linear(ix.c, self.layer.w, bias=self.layer.bias)
			ix.idep = F.linear(ix.idep, self.layer.w)
			for i in range(len(ix.edep)):
				ix.edep[i] = F.linear(ix.edep[i], self.layer.w)
			ix.shape = list(ix.c.shape[1:])
			ix.n = list(ix.c[0].view(-1).size())[0]
			ix.concretize()
			return ix

		if isinstance(ix, Interval):
			c = ix.c
			e = ix.e

			if self.wc_matrix is None:
				c = F.linear(c, self.layer.w, bias=self.layer.bias)
				e = F.linear(e, self.layer.w.abs())
			else:
				weight = self.wc_matrix.matmul(self.layer.w)
				bias = self.wc_matrix.matmul(self.layer.bias)
				
				c = weight.matmul(c.unsqueeze(-1)) + bias.unsqueeze(-1)
				e = weight.abs().matmul(e.unsqueeze(-1))

				c, e = c.squeeze(-1), e.squeeze(-1)

			ix.update_lu(c-e, c+e)
			
			return ix