def __init__(self, in_channel=3, out_dim=10, pooling=MaxPool2dInterval):
    """Build the IntervalCNN variant whose input convolution lives inside ``c1``.

    Args:
        in_channel: number of input channels for the first convolution.
        out_dim: width of the ``last`` layer.
        pooling: pooling class, called as ``pooling(2, stride=2, padding=p)``.
    """
    super(IntervalCNN, self).__init__()
    # In this variant the input convolution (input_layer=True) is the first
    # module of ``c1`` instead of a separate ``self.input`` attribute.
    self.c1 = nn.Sequential(
        Conv2dInterval(in_channel, 32, kernel_size=3, stride=1, padding=1, input_layer=True),
        nn.ReLU(),
        Conv2dInterval(32, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        Conv2dInterval(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        pooling(2, stride=2, padding=0),
        IntervalDropout(0.25),
    )
    self.c2 = nn.Sequential(
        Conv2dInterval(64, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        Conv2dInterval(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        pooling(2, stride=2, padding=0),
        IntervalDropout(0.25),
    )
    self.c3 = nn.Sequential(
        Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        # padding=1 on the last pool; NOTE(review): presumably chosen so the
        # flattened width matches 128 * 5 * 5 below — confirm input size.
        pooling(2, stride=2, padding=1),
        IntervalDropout(0.25),
    )
    self.fc1 = nn.Sequential(LinearInterval(128 * 5 * 5, 256), nn.ReLU())
    self.last = LinearInterval(256, out_dim)
    # Nine trainable logits, initialised to zero. (A second, immediately
    # overwritten Parameter assignment was removed: only this one ever took
    # effect.)
    self.a = nn.Parameter(torch.zeros(9), requires_grad=True)
    self.e = torch.zeros(9)
    self.bounds = None
def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
    """Create the layers of the interval MLP: fc1 -> fc2 -> last.

    Args:
        out_dim: width of the ``last`` layer.
        in_channel: number of input channels.
        img_sz: side length of the input image; ``in_dim`` is set to
            ``in_channel * img_sz * img_sz``.
        hidden_dim: width of ``fc1`` and ``fc2``.
    """
    super(IntervalMLP, self).__init__()
    flat_features = in_channel * img_sz * img_sz
    self.in_dim = flat_features
    self.fc1 = LinearInterval(flat_features, hidden_dim, input_layer=True)
    self.fc2 = LinearInterval(hidden_dim, hidden_dim)
    # Subject to be replaced dependent on task
    self.last = LinearInterval(hidden_dim, out_dim)
    # Three trainable logits with a non-uniform starting point.
    self.a = nn.Parameter(torch.Tensor([2, 1, 0]), requires_grad=True)
    self.e = torch.zeros(3)
    self.bounds = None
def __init__(self, eps=0):
    """Create the layers of the "Large" interval network.

    Args:
        eps: stored on ``self.eps``; not otherwise used by this constructor.
    """
    # Explicit two-argument form: zero-argument super() needs the
    # ``__class__`` cell that only exists for functions defined inside a
    # class body, and raises RuntimeError otherwise. This form works in both
    # places and matches the other constructors in this file.
    super(Large, self).__init__()
    self.conv1 = Conv2dInterval(3, 64, 3, 1, input_layer=True)
    self.conv2 = Conv2dInterval(64, 64, 3, 1)
    self.conv3 = Conv2dInterval(64, 128, 3, 2)
    self.conv4 = Conv2dInterval(128, 128, 3, 1)
    self.conv5 = Conv2dInterval(128, 128, 3, 1)
    self.fc1 = LinearInterval(128 * 9 * 9, 200)
    self.last = LinearInterval(200, 10)
    # Seven trainable logits, initialised to zero.
    self.a = nn.Parameter(torch.zeros(7), requires_grad=True)
    self.e = None
    self.eps = eps
    self.bounds = None
def create_model(self):
    """Instantiate the configured backbone and give it one head per task.

    Reads ``model_type``, ``model_name``, ``out_dim`` and ``model_weights``
    from ``self.config``. The backbone's ``last`` layer is replaced with an
    ``nn.ModuleDict`` holding a ``LinearInterval(n_feat, out_dim)`` head for
    every ``(task, out_dim)`` pair in ``config['out_dim']``. The instance's
    ``logits`` is rebound to return ``{task_name: head_output}``; for a
    single-headed model the output is ``{'All': output}``. If
    ``model_weights`` is not None, the state dict is loaded from that path.

    Returns:
        The modified backbone module.
    """
    config = self.config

    # Backbone constructor lives at models.<model_type>.<model_name>
    # (MLP, LeNet, VGG, ResNet ... etc).
    backbone_module = models.__dict__[config['model_type']]
    model = backbone_module.__dict__[config['model_name']]()

    # Every task head consumes the same feature width as the original head.
    feature_width = model.last.in_features
    heads = nn.ModuleDict()
    for task_name, task_out in config['out_dim'].items():
        heads[task_name] = LinearInterval(feature_width, task_out)
    model.last = heads

    def new_logits(self, x):
        # Run every task head on the shared features.
        return {task_name: head(x) for task_name, head in self.last.items()}

    # Bound on the instance so it shadows the class-level ``logits``.
    model.logits = MethodType(new_logits, model)

    weights_path = config['model_weights']
    if weights_path is not None:
        print('=> Load model weights:', weights_path)
        # Keep every storage where it was deserialized, i.e. load to CPU.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state = torch.load(weights_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(state)
        print('=> Load Done')
    return model
class IntervalMLP(nn.Module):
    """Two-hidden-layer MLP built from interval linear layers.

    Activations are laid out as three equal slices along dim 1: ``forward``
    returns only the first slice and ``save_bounds`` stores the other two.
    NOTE(review): presumably (centre, lower, upper) — confirm against
    ``LinearInterval``.

    ``print_eps``, ``reset_importance`` and ``set_eps`` iterate
    ``self.last.items()``, so they expect ``self.last`` to have been replaced
    by an ``nn.ModuleDict`` of per-task heads (e.g. by ``create_model``).
    """

    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super(IntervalMLP, self).__init__()
        self.in_dim = in_channel * img_sz * img_sz
        self.fc1 = LinearInterval(self.in_dim, hidden_dim, input_layer=True)
        self.fc2 = LinearInterval(hidden_dim, hidden_dim)
        # Subject to be replaced dependent on task
        self.last = LinearInterval(hidden_dim, out_dim)
        # Trainable logits for the eps split: fc1, fc2, then all heads.
        self.a = nn.Parameter(torch.zeros(3), requires_grad=True)
        self.e = None
        self.bounds = None

    def save_bounds(self, x):
        """Store the second and third thirds of ``x`` (dim 1) in ``self.bounds``."""
        s = x.size(1) // 3
        self.bounds = x[:, s:2 * s], x[:, 2 * s:]

    def calc_eps(self, r):
        """Split the total budget ``r`` across layers: ``r * softmax(self.a)``.

        ``torch.softmax`` is used instead of ``a.exp() / a.exp().sum()``,
        which overflows to inf/NaN once a trainable logit grows large.
        """
        self.e = r * torch.softmax(self.a, dim=0)

    def print_eps(self):
        """Print sum/mean/std of the ``eps`` tensor of every layer."""
        e1 = self.fc1.eps.detach()
        e2 = self.fc2.eps.detach()
        print(f"sum: {e1.sum()} - mean: {e1.mean()} - std: {e1.std()}")
        print(f"sum: {e2.sum()} - mean: {e2.mean()} - std: {e2.std()}")
        for name, layer in self.last.items():
            head_eps = layer.eps.detach()
            print(f"last-{name} sum: {head_eps.sum()} - mean: {head_eps.mean()} - std: {head_eps.std()}")

    def reset_importance(self):
        """Call ``rest_importance`` on every interval layer.

        NOTE(review): the layer method is spelled ``rest_importance`` here;
        confirm it matches the name the interval layers define.
        """
        self.fc1.rest_importance()
        self.fc2.rest_importance()
        for layer in self.last.values():
            layer.rest_importance()

    def set_eps(self, eps, trainable=False):
        """Push an eps value into every layer via ``calc_eps``.

        With ``trainable=True`` the budget ``eps`` is split by ``calc_eps``
        (fc1 gets ``e[0]``, fc2 ``e[1]``, every head ``e[2]``); otherwise each
        layer receives ``eps`` unchanged.
        """
        if trainable:
            self.calc_eps(eps)
            self.fc1.calc_eps(self.e[0])
            self.fc2.calc_eps(self.e[1])
            for layer in self.last.values():
                layer.calc_eps(self.e[2])
        else:
            self.fc1.calc_eps(eps)
            self.fc2.calc_eps(eps)
            for layer in self.last.values():
                layer.calc_eps(eps)

    def features(self, x):
        """Flatten ``x`` to ``in_dim`` columns, apply fc1/fc2 with ReLU, record bounds."""
        x = x.view(-1, self.in_dim)
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        self.save_bounds(x)
        return x

    def logits(self, x):
        """Apply the head(s): a ``{task: output}`` dict for a ModuleDict ``last``."""
        # A ModuleDict is not callable, and forward() expects a dict, so
        # handle the multi-head layout here rather than relying on an
        # instance-level monkey patch.
        if isinstance(self.last, nn.ModuleDict):
            return {task: head(x) for task, head in self.last.items()}
        return self.last(x)

    def forward(self, x):
        """Return ``{task: first third of that head's output along dim 1}``."""
        x = self.features(x)
        x = self.logits(x)
        return {k: v[:, :v.size(1) // 3] for k, v in x.items()}
class Large(nn.Module):
    """Interval network: five interval convolutions and one hidden FC layer.

    Activations are laid out as three equal slices along dim 1: ``forward``
    returns only the first slice and ``save_bounds`` stores the other two.
    NOTE(review): the ``128 * 9 * 9`` width of fc1 presumably assumes 32x32
    inputs with unpadded 3x3 kernels — confirm ``Conv2dInterval``'s
    positional argument order (in, out, kernel, stride).

    ``print_eps`` and ``set_eps`` iterate ``self.last.items()``, so they
    expect ``self.last`` to have been replaced by an ``nn.ModuleDict`` of
    per-task heads (e.g. by ``create_model``).
    """

    def __init__(self, eps=0):
        super().__init__()
        self.conv1 = Conv2dInterval(3, 64, 3, 1, input_layer=True)
        self.conv2 = Conv2dInterval(64, 64, 3, 1)
        self.conv3 = Conv2dInterval(64, 128, 3, 2)
        self.conv4 = Conv2dInterval(128, 128, 3, 1)
        self.conv5 = Conv2dInterval(128, 128, 3, 1)
        self.fc1 = LinearInterval(128 * 9 * 9, 200)
        self.last = LinearInterval(200, 10)
        # Trainable logits: one per body layer (6) plus one shared by all heads.
        self.a = nn.Parameter(torch.zeros(7), requires_grad=True)
        self.e = None
        # Stored only; no method of this class reads it.
        self.eps = eps
        self.bounds = None

    def _body_layers(self):
        """Return the non-head interval layers in forward order."""
        return (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.fc1)

    def save_bounds(self, x):
        """Store the second and third thirds of ``x`` (dim 1) in ``self.bounds``."""
        s = x.size(1) // 3
        self.bounds = x[:, s:2 * s], x[:, 2 * s:]

    def calc_eps(self, r):
        """Split the total budget ``r`` across layers: ``r * softmax(self.a)``.

        ``torch.softmax`` replaces ``a.exp() / a.exp().sum()``, which
        overflows to inf/NaN once a trainable logit grows large.
        """
        self.e = r * torch.softmax(self.a, dim=0)

    def print_eps(self):
        """Print sum/mean/std of the ``eps`` tensor of every layer."""
        for layer in self._body_layers():
            e1 = layer.eps.detach()
            print(f"sum: {e1.sum()} - mean: {e1.mean()} - std: {e1.std()}")
        for name, layer in self.last.items():
            head_eps = layer.eps.detach()
            print(
                f"last-{name} sum: {head_eps.sum()} - mean: {head_eps.mean()} - std: {head_eps.std()}"
            )

    def reset_importance(self):
        """Do nothing.

        NOTE(review): per-layer importance resets are disabled for this model
        (they were left commented out); confirm that is intentional.
        """
        pass

    def set_eps(self, eps, trainable=False):
        """Push an eps value into every layer via ``calc_eps``.

        With ``trainable=True`` the budget ``eps`` is split by ``calc_eps``:
        body layer i gets ``e[i]`` and every head gets ``e[6]``. Otherwise
        each layer receives ``eps`` unchanged.
        """
        body = self._body_layers()
        if trainable:
            self.calc_eps(eps)
            for i, layer in enumerate(body):
                layer.calc_eps(self.e[i])
            head_eps = self.e[len(body)]
        else:
            for layer in body:
                layer.calc_eps(eps)
            head_eps = eps
        for layer in self.last.values():
            layer.calc_eps(head_eps)

    def features(self, x):
        """Run conv1..conv5 and fc1 (each followed by ReLU) and record bounds."""
        x = f.relu(self.conv1(x))
        x = f.relu(self.conv2(x))
        x = f.relu(self.conv3(x))
        x = f.relu(self.conv4(x))
        x = f.relu(self.conv5(x))
        x = torch.flatten(x, 1)
        x = f.relu(self.fc1(x))
        self.save_bounds(x)
        return x

    def logits(self, x):
        """Apply the head(s): a ``{task: output}`` dict for a ModuleDict ``last``."""
        # A ModuleDict is not callable, and forward() expects a dict.
        if isinstance(self.last, nn.ModuleDict):
            return {task: head(x) for task, head in self.last.items()}
        return self.last(x)

    def forward(self, x):
        """Return ``{task: first third of that head's output along dim 1}``."""
        x = self.features(x)
        x = self.logits(x)
        return {k: v[:, :v.size(1) // 3] for k, v in x.items()}
class IntervalCNN(nn.Module):
    """Interval CNN: input conv, three conv stages, one hidden FC layer, head.

    Activations are laid out as three equal slices along dim 1: ``forward``
    returns only the first slice and ``save_bounds`` stores the other two.
    NOTE(review): fc1's ``128 * 5 * 5`` input width presumably corresponds to
    32x32 inputs — confirm.

    ``print_eps``, ``reset_importance`` and ``set_eps`` iterate
    ``self.last.items()``, so they expect ``self.last`` to have been replaced
    by an ``nn.ModuleDict`` of per-task heads (e.g. by ``create_model``).
    """

    def __init__(self, in_channel=3, out_dim=10, pooling=MaxPool2dInterval):
        super(IntervalCNN, self).__init__()
        self.input = Conv2dInterval(in_channel, 32, kernel_size=3, stride=1, padding=1, input_layer=True)
        self.c1 = nn.Sequential(
            Conv2dInterval(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            pooling(2, stride=2, padding=0),
            IntervalDropout(0.25),
        )
        self.c2 = nn.Sequential(
            Conv2dInterval(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            pooling(2, stride=2, padding=0),
            IntervalDropout(0.25),
        )
        self.c3 = nn.Sequential(
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            pooling(2, stride=2, padding=1),
            IntervalDropout(0.25),
        )
        self.fc1 = nn.Sequential(LinearInterval(128 * 5 * 5, 256), nn.ReLU())
        self.last = LinearInterval(256, out_dim)
        # Trainable logits: one per interval layer (8) plus one shared by all heads.
        self.a = nn.Parameter(torch.zeros(9), requires_grad=True)
        self.e = None
        self.bounds = None

    def _interval_layers(self):
        """Return the non-head interval layers in forward order."""
        return (
            self.input,
            self.c1[0], self.c1[2],
            self.c2[0], self.c2[2],
            self.c3[0], self.c3[2],
            self.fc1[0],
        )

    def save_bounds(self, x):
        """Store the second and third thirds of ``x`` (dim 1) in ``self.bounds``."""
        s = x.size(1) // 3
        self.bounds = x[:, s:2 * s], x[:, 2 * s:]

    def print_eps(self):
        """Print sum/mean/std and min/max of every layer's ``eps`` tensor."""
        for layer in self._interval_layers():
            e = layer.eps.detach()
            print(f"sum: {e.sum()} - mean: {e.mean()} - std: {e.std()}")
            print(f"min: {e.min()} - max: {e.max()}")
        for name, layer in self.last.items():
            head_eps = layer.eps.detach()
            print(
                f"last-{name} sum: {head_eps.sum()} - mean: {head_eps.mean()} - std: {head_eps.std()}"
            )
            print(f"min: {head_eps.min()} - max: {head_eps.max()}")

    def calc_eps(self, r):
        """Split the total budget ``r`` across layers: ``r * softmax(self.a)``.

        ``torch.softmax`` replaces ``a.exp() / a.exp().sum()``, which
        overflows to inf/NaN once a trainable logit grows large.
        """
        self.e = r * torch.softmax(self.a, dim=0)

    def reset_importance(self):
        """Call ``rest_importance`` on every interval layer and head.

        NOTE(review): confirm ``rest_importance`` is the method name the
        interval layers define.
        """
        for layer in self._interval_layers():
            layer.rest_importance()
        for layer in self.last.values():
            layer.rest_importance()

    def set_eps(self, eps, trainable=False):
        """Push an eps value into every layer via ``calc_eps``.

        With ``trainable=True`` the budget ``eps`` is split by ``calc_eps``:
        interval layer i gets ``e[i]`` and every head gets ``e[8]``.
        Otherwise each layer receives ``eps`` unchanged.
        """
        layers = self._interval_layers()
        if trainable:
            self.calc_eps(eps)
            for i, layer in enumerate(layers):
                layer.calc_eps(self.e[i])
            head_eps = self.e[len(layers)]
        else:
            for layer in layers:
                layer.calc_eps(eps)
            head_eps = eps
        for layer in self.last.values():
            layer.calc_eps(head_eps)

    def features(self, x):
        """Run input conv, c1..c3, flatten and fc1, and record bounds."""
        x = self.input(x)
        x = self.c1(x)
        x = self.c2(x)
        x = self.c3(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        self.save_bounds(x)
        return x

    def logits(self, x):
        """Apply the head(s): a ``{task: output}`` dict for a ModuleDict ``last``."""
        # A ModuleDict is not callable, and forward() expects a dict.
        if isinstance(self.last, nn.ModuleDict):
            return {task: head(x) for task, head in self.last.items()}
        return self.last(x)

    def forward(self, x):
        """Return ``{task: first third of that head's output along dim 1}``."""
        x = self.features(x)
        x = self.logits(x)
        return {k: v[:, :v.size(1) // 3] for k, v in x.items()}