Example 1
    def __init__(self, in_channel=3, out_dim=10, pooling=MaxPool2dInterval):
        super(IntervalCNN, self).__init__()

        self.c1 = nn.Sequential(
            Conv2dInterval(in_channel, 32, kernel_size=3, stride=1, padding=1,
                           input_layer=True),
            nn.ReLU(),
            Conv2dInterval(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(), pooling(2, stride=2, padding=0), IntervalDropout(0.25))
        self.c2 = nn.Sequential(
            Conv2dInterval(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(), pooling(2, stride=2, padding=0), IntervalDropout(0.25))
        self.c3 = nn.Sequential(
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(), pooling(2, stride=2, padding=1), IntervalDropout(0.25))
        self.fc1 = nn.Sequential(LinearInterval(128 * 5 * 5, 256), nn.ReLU())
        self.last = LinearInterval(256, out_dim)
        self.a = nn.Parameter(torch.zeros(9), requires_grad=True)
        self.e = torch.zeros(9)
        self.bounds = None
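
The 128 * 5 * 5 fed to fc1 follows from the pooling arithmetic: the padding=1 3x3 convolutions preserve the spatial size, and the three pooling stages shrink the input from 32x32 to 5x5. A quick check, assuming a CIFAR-like 32x32 input (the snippet does not fix the input size):

def pool_out(size, kernel=2, stride=2, padding=0):
    # Standard pooling output-size formula (floor mode), as in nn.MaxPool2d.
    return (size + 2 * padding - kernel) // stride + 1

s = 32                      # assumed input size; the padded convs keep it
s = pool_out(s)             # after c1: 32 -> 16
s = pool_out(s)             # after c2: 16 -> 8
s = pool_out(s, padding=1)  # after c3: 8 -> 5
print(128 * s * s)          # 3200 == 128 * 5 * 5, the in_features of fc1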
Example 2
    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super(IntervalMLP, self).__init__()
        self.in_dim = in_channel * img_sz * img_sz
        self.fc1 = LinearInterval(self.in_dim, hidden_dim, input_layer=True)
        self.fc2 = LinearInterval(hidden_dim, hidden_dim)
        # Subject to replacement depending on the task (see Example 4)
        self.last = LinearInterval(hidden_dim, out_dim)
        self.a = nn.Parameter(torch.Tensor([2, 1, 0]), requires_grad=True)
        self.e = torch.zeros(3)
        self.bounds = None
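
None of these snippets include LinearInterval or Conv2dInterval themselves. From save_bounds and forward in Examples 5-7, each interval layer evidently packs three equal slices along the feature dimension: the middle activation followed by the lower and upper bounds. A minimal interval-bound-propagation sketch of what such a linear layer could look like under that packing (an assumption, not the library's actual implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearIntervalSketch(nn.Linear):
    """Hypothetical stand-in for LinearInterval, not the library code."""

    def __init__(self, in_features, out_features, input_layer=False):
        super().__init__(in_features, out_features)
        self.input_layer = input_layer
        self.eps = 0.0  # per-layer radius, set through calc_eps in the examples

    def forward(self, x):
        if self.input_layer:
            mid = lower = upper = x               # raw input, no interval yet
        else:
            mid, lower, upper = x.chunk(3, dim=1)
        center = (lower + upper) / 2
        radius = (upper - lower) / 2 + self.eps   # widen by this layer's eps
        c = F.linear(center, self.weight, self.bias)
        r = F.linear(radius, self.weight.abs())   # |W| propagates the radius
        mid = F.linear(mid, self.weight, self.bias)
        return torch.cat([mid, c - r, c + r], dim=1)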
Example 3
    def __init__(self, eps=0):
        super().__init__()
        self.conv1 = Conv2dInterval(3, 64, 3, 1, input_layer=True)
        self.conv2 = Conv2dInterval(64, 64, 3, 1)
        self.conv3 = Conv2dInterval(64, 128, 3, 2)
        self.conv4 = Conv2dInterval(128, 128, 3, 1)
        self.conv5 = Conv2dInterval(128, 128, 3, 1)
        self.fc1 = LinearInterval(128 * 9 * 9, 200)
        self.last = LinearInterval(200, 10)

        self.a = nn.Parameter(torch.zeros(7), requires_grad=True)
        self.e = None

        self.eps = eps
        self.bounds = None
Example 4
    def create_model(self):
        cfg = self.config

        # Define the backbone (MLP, LeNet, VGG, ResNet ... etc) of model
        model = models.__dict__[cfg['model_type']].__dict__[
            cfg['model_name']]()

        # Apply network surgery to the backbone
        # Create the heads for tasks (single-task or multi-task)
        n_feat = model.last.in_features

        # The output of the model will be a dict: {task_name1:output1, task_name2:output2 ...}
        # For a single-headed model the output will be {'All':output}
        model.last = nn.ModuleDict()
        for task, out_dim in cfg['out_dim'].items():
            model.last[task] = LinearInterval(n_feat, out_dim)

        # Redefine the task-dependent function
        def new_logits(self, x):
            outputs = {}
            for task, func in self.last.items():
                outputs[task] = func(x)
            return outputs

        # Replace the task-dependent function
        model.logits = MethodType(new_logits, model)
        # Load pre-trained weights
        if cfg['model_weights'] is not None:
            print('=> Load model weights:', cfg['model_weights'])
            model_state = torch.load(
                cfg['model_weights'],
                map_location=lambda storage, loc: storage)  # Load to CPU.
            model.load_state_dict(model_state)
            print('=> Load Done')
        return model
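
A usage sketch for create_model. Only the cfg keys come from the method itself; the concrete values ('mlp', 'IntervalMLP', the task names) and the agent stub are illustrative assumptions:

from types import MethodType, SimpleNamespace

# Hypothetical config; only the keys are taken from create_model itself.
cfg = {
    'model_type': 'mlp',                     # module looked up in models.__dict__
    'model_name': 'IntervalMLP',             # factory looked up inside that module
    'out_dim': {'task1': 10, 'task2': 10},   # one LinearInterval head per task
    'model_weights': None,                   # optional checkpoint path, loaded to CPU
}
agent = SimpleNamespace(config=cfg)          # stand-in for the real agent object
agent.create_model = MethodType(create_model, agent)
model = agent.create_model()
# model.logits(features) now returns {'task1': ..., 'task2': ...}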
Example 5
class IntervalMLP(nn.Module):

    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super(IntervalMLP, self).__init__()
        self.in_dim = in_channel * img_sz * img_sz
        self.fc1 = LinearInterval(self.in_dim, hidden_dim, input_layer=True)
        self.fc2 = LinearInterval(hidden_dim, hidden_dim)
        # Subject to replacement depending on the task (see Example 4)
        self.last = LinearInterval(hidden_dim, out_dim)
        self.a = nn.Parameter(torch.zeros(3), requires_grad=True)
        self.e = None

        self.bounds = None

    def save_bounds(self, x):
        s = x.size(1) // 3
        self.bounds = x[:, s:2*s], x[:, 2*s:]

    def calc_eps(self, r):
        exp = self.a.exp()
        self.e = r * exp / exp.sum()

    def print_eps(self):
        e1 = self.fc1.eps.detach()
        e2 = self.fc2.eps.detach()
        print(f"sum: {e1.sum()} - mean: {e1.mean()} - std: {e1.std()}")
        print(f"sum: {e2.sum()} - mean: {e2.mean()} - std: {e2.std()}")

        for name, layer in self.last.items():  # assumes last is an nn.ModuleDict (Example 4)
            e = layer.eps.detach()
            print(f"last-{name} sum: {e.sum()} - mean: {e.mean()} - std: {e.std()}")

    def reset_importance(self):
        self.fc1.rest_importance()
        self.fc2.rest_importance()
        for _, layer in self.last.items():
            layer.rest_importance()

    def set_eps(self, eps, trainable=False):
        if trainable:
            self.calc_eps(eps)
            self.fc1.calc_eps(self.e[0])
            self.fc2.calc_eps(self.e[1])
            for _, layer in self.last.items():
                layer.calc_eps(self.e[2])
        else:
            self.fc1.calc_eps(eps)
            self.fc2.calc_eps(eps)
            for _, layer in self.last.items():
                layer.calc_eps(eps)

    def features(self, x):
        x = x.view(-1, self.in_dim)
        x = f.relu(self.fc1(x))
        x = f.relu(self.fc2(x))
        self.save_bounds(x)
        return x

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return {k: v[:, :v.size(1)//3] for k, v in x.items()}
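
forward iterates over x.items(), so IntervalMLP only runs once last has been swapped for a dict of task heads as in Example 4. A minimal end-to-end sketch under that assumption, with a single head named 'All' and an illustrative eps budget of 0.1:

import torch
import torch.nn as nn
from types import MethodType

model = IntervalMLP(out_dim=10, in_channel=1, img_sz=32, hidden_dim=256)

# Replicate the surgery from Example 4 with a single head.
model.last = nn.ModuleDict({'All': LinearInterval(256, 10)})
def new_logits(self, x):
    return {task: head(x) for task, head in self.last.items()}
model.logits = MethodType(new_logits, model)

model.set_eps(0.1, trainable=True)       # split the budget via softmax over self.a
out = model(torch.randn(4, 1, 32, 32))   # {'All': middle-slice logits}
lower, upper = model.bounds              # bound slices saved from the fc2 output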
Example 6
class Large(nn.Module):
    def __init__(self, eps=0):
        super().__init__()
        self.conv1 = Conv2dInterval(3, 64, 3, 1, input_layer=True)
        self.conv2 = Conv2dInterval(64, 64, 3, 1)
        self.conv3 = Conv2dInterval(64, 128, 3, 2)
        self.conv4 = Conv2dInterval(128, 128, 3, 1)
        self.conv5 = Conv2dInterval(128, 128, 3, 1)
        self.fc1 = LinearInterval(128 * 9 * 9, 200)
        self.last = LinearInterval(200, 10)

        self.a = nn.Parameter(torch.zeros(7), requires_grad=True)
        self.e = None

        self.eps = eps
        self.bounds = None

    def save_bounds(self, x):
        s = x.size(1) // 3
        self.bounds = x[:, s:2 * s], x[:, 2 * s:]

    def calc_eps(self, r):
        exp = self.a.exp()
        self.e = r * exp / exp.sum()

    def print_eps(self):
        for c in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                  self.fc1):
            e1 = c.eps.detach()
            print(f"sum: {e1.sum()} - mean: {e1.mean()} - std: {e1.std()}")

        for name, layer in self.last.items():  # assumes last is an nn.ModuleDict
            e = layer.eps.detach()
            print(f"last-{name} sum: {e.sum()} - mean: {e.mean()} - std: {e.std()}")

    def reset_importance(self):
        pass
        # self.conv1.reset_importance()
        # self.conv2.reset_importance()
        # self.conv3.reset_importance()
        # self.conv4.reset_importance()
        # self.conv5.reset_importance()
        # self.fc1.reset_importance()
        # for _, layer in self.last.items():
        #     layer.reset_importance()

    def set_eps(self, eps, trainable=False):
        if trainable:
            self.calc_eps(eps)

            self.conv1.calc_eps(self.e[0])
            self.conv2.calc_eps(self.e[1])
            self.conv3.calc_eps(self.e[2])
            self.conv4.calc_eps(self.e[3])
            self.conv5.calc_eps(self.e[4])
            self.fc1.calc_eps(self.e[5])
            for _, layer in self.last.items():
                layer.calc_eps(self.e[6])
        else:
            self.conv1.calc_eps(eps)
            self.conv2.calc_eps(eps)
            self.conv3.calc_eps(eps)
            self.conv4.calc_eps(eps)
            self.conv5.calc_eps(eps)
            self.fc1.calc_eps(eps)
            for _, layer in self.last.items():
                layer.calc_eps(eps)

    def features(self, x):
        x = f.relu(self.conv1(x))
        x = f.relu(self.conv2(x))
        x = f.relu(self.conv3(x))
        x = f.relu(self.conv4(x))
        x = f.relu(self.conv5(x))
        x = torch.flatten(x, 1)
        x = f.relu(self.fc1(x))
        self.save_bounds(x)
        return x

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return {k: v[:, :v.size(1) // 3] for k, v in x.items()}
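
The 128 * 9 * 9 in fc1 checks out for a 32x32 input (assumed here; the snippet does not fix the input size) if Conv2dInterval takes nn.Conv2d-style (in_channels, out_channels, kernel_size, stride) arguments with no padding:

def conv_out(size, kernel=3, stride=1, padding=0):
    # Standard convolution output-size formula, as in nn.Conv2d.
    return (size + 2 * padding - kernel) // stride + 1

s = 32
for stride in (1, 1, 2, 1, 1):  # conv1 .. conv5
    s = conv_out(s, stride=stride)
print(128 * s * s)              # 10368 == 128 * 9 * 9, the in_features of fc1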
Example 7
class IntervalCNN(nn.Module):
    def __init__(self, in_channel=3, out_dim=10, pooling=MaxPool2dInterval):
        super(IntervalCNN, self).__init__()

        self.input = Conv2dInterval(in_channel,
                                    32,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    input_layer=True)
        self.c1 = nn.Sequential(
            Conv2dInterval(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(), pooling(2, stride=2, padding=0), IntervalDropout(0.25))
        self.c2 = nn.Sequential(
            Conv2dInterval(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(), pooling(2, stride=2, padding=0), IntervalDropout(0.25))
        self.c3 = nn.Sequential(
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(), pooling(2, stride=2, padding=1), IntervalDropout(0.25))
        self.fc1 = nn.Sequential(LinearInterval(128 * 5 * 5, 256), nn.ReLU())
        self.last = LinearInterval(256, out_dim)
        self.a = nn.Parameter(torch.zeros(9), requires_grad=True)
        self.e = None

        self.bounds = None

    def save_bounds(self, x):
        s = x.size(1) // 3
        self.bounds = x[:, s:2 * s], x[:, 2 * s:]

    def print_eps(self):
        e = self.input.eps.detach()
        print(f"sum: {e.sum()} - mean: {e.mean()} - std: {e.std()}")
        print(f"min: {e.min()} - max: {e.max()}")

        for c in (self.c1, self.c2, self.c3):
            e1 = c[0].eps.detach()
            e2 = c[2].eps.detach()
            print(f"sum: {e1.sum()} - mean: {e1.mean()} - std: {e1.std()}")
            print(f"min: {e1.min()} - max: {e1.max()}")
            print(f"sum: {e2.sum()} - mean: {e2.mean()} - std: {e2.std()}")
            print(f"min: {e2.min()} - max: {e2.max()}")

        e = self.fc1[0].eps.detach()
        print(f"sum: {e.sum()} - mean: {e.mean()} - std: {e.std()}")
        print(f"min: {e.min()} - max: {e.max()}")

        for name, layer in self.last.items():  # assumes last is an nn.ModuleDict
            e = layer.eps.detach()
            print(f"last-{name} sum: {e.sum()} - mean: {e.mean()} - std: {e.std()}")
            print(f"min: {e.min()} - max: {e.max()}")

    def calc_eps(self, r):
        exp = self.a.exp()
        self.e = r * exp / exp.sum()

    def reset_importance(self):
        self.input.rest_importance()
        self.c1[0].rest_importance()
        self.c1[2].rest_importance()
        self.c2[0].rest_importance()
        self.c2[2].rest_importance()
        self.c3[0].rest_importance()
        self.c3[2].rest_importance()
        self.fc1[0].rest_importance()
        for _, layer in self.last.items():
            layer.rest_importance()

    def set_eps(self, eps, trainable=False):
        if trainable:
            self.calc_eps(eps)

            self.input.calc_eps(self.e[0])
            self.c1[0].calc_eps(self.e[1])
            self.c1[2].calc_eps(self.e[2])
            self.c2[0].calc_eps(self.e[3])
            self.c2[2].calc_eps(self.e[4])
            self.c3[0].calc_eps(self.e[5])
            self.c3[2].calc_eps(self.e[6])
            self.fc1[0].calc_eps(self.e[7])
            for _, layer in self.last.items():
                layer.calc_eps(self.e[8])
        else:
            self.input.calc_eps(eps)
            self.c1[0].calc_eps(eps)
            self.c1[2].calc_eps(eps)
            self.c2[0].calc_eps(eps)
            self.c2[2].calc_eps(eps)
            self.c3[0].calc_eps(eps)
            self.c3[2].calc_eps(eps)
            self.fc1[0].calc_eps(eps)
            for _, layer in self.last.items():
                layer.calc_eps(eps)

    def features(self, x):
        x = self.input(x)
        x = self.c1(x)
        x = self.c2(x)
        x = self.c3(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        self.save_bounds(x)
        return x

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return {k: v[:, :v.size(1) // 3] for k, v in x.items()}
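
In every model above, calc_eps turns the total perturbation budget r into per-layer shares r * softmax(a); since a starts at zeros, the nine interval layers of IntervalCNN begin with equal shares, and the split is learnable because a is an nn.Parameter. A quick numeric check with an illustrative budget:

import torch

a = torch.zeros(9)                 # self.a at initialization
r = 0.9                            # illustrative budget for set_eps(r, trainable=True)
e = r * a.exp() / a.exp().sum()    # same formula as calc_eps
print(e)                           # nine equal shares of 0.1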