# Example 1
class Large(nn.Module):
    """Large interval-bounded CNN: five interval convolutions followed by an
    interval linear layer and a task head (`last`) producing 10 outputs.

    Activations carry three stacked parts along dim 1: [mid | lower | upper].
    """

    def __init__(self, eps=0):
        super().__init__()
        self.conv1 = Conv2dInterval(3, 64, 3, 1, input_layer=True)
        self.conv2 = Conv2dInterval(64, 64, 3, 1)
        self.conv3 = Conv2dInterval(64, 128, 3, 2)
        self.conv4 = Conv2dInterval(128, 128, 3, 1)
        self.conv5 = Conv2dInterval(128, 128, 3, 1)
        self.fc1 = LinearInterval(128 * 9 * 9, 200)
        self.last = LinearInterval(200, 10)

        # One trainable budget logit per eps-consuming layer
        # (conv1..conv5, fc1, last) — see calc_eps/set_eps.
        self.a = nn.Parameter(torch.zeros(7), requires_grad=True)
        self.e = None

        self.eps = eps
        self.bounds = None

    def save_bounds(self, x):
        # Split off the lower/upper thirds of the interval activation.
        third = x.size(1) // 3
        self.bounds = (x[:, third:2 * third], x[:, 2 * third:])

    def calc_eps(self, r):
        # Softmax over `a` spreads the total budget r across the 7 layers.
        weights = self.a.exp()
        self.e = r * weights / weights.sum()

    def print_eps(self):
        """Log per-layer eps statistics (debugging aid)."""
        hidden = (self.conv1, self.conv2, self.conv3,
                  self.conv4, self.conv5, self.fc1)
        for layer in hidden:
            stats = layer.eps.detach()
            print(f"sum: {stats.sum()} - mean: {stats.mean()} - std: {stats.std()}")

        for name, head in self.last.items():
            head_eps = head.eps.detach()
            print(
                f"last-{name} sum: {head_eps.sum()} - mean: {head_eps.mean()} - std: {head_eps.std()}"
            )

    def reset_importance(self):
        # Intentionally a no-op in this model: the per-layer resets that the
        # sibling models perform are disabled here.
        pass

    def set_eps(self, eps, trainable=False):
        """Assign an eps budget to every layer.

        trainable=True splits `eps` via the learned softmax weights;
        otherwise every layer receives the same scalar `eps`.
        """
        hidden = (self.conv1, self.conv2, self.conv3,
                  self.conv4, self.conv5, self.fc1)
        if trainable:
            self.calc_eps(eps)
            for layer, budget in zip(hidden, self.e):
                layer.calc_eps(budget)
            for _, head in self.last.items():
                head.calc_eps(self.e[6])
        else:
            for layer in hidden:
                layer.calc_eps(eps)
            for _, head in self.last.items():
                head.calc_eps(eps)

    def features(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = f.relu(conv(x))
        x = torch.flatten(x, 1)
        x = f.relu(self.fc1(x))
        self.save_bounds(x)
        return x

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        out = self.logits(self.features(x))
        # Keep only the middle (point-prediction) third of each head's output.
        return {name: v[:, :v.size(1) // 3] for name, v in out.items()}
# Example 2
class IntervalMLP(nn.Module):
    """Two-hidden-layer interval MLP with a replaceable task head (`last`).

    Activations carry three stacked parts along dim 1: [mid | lower | upper].
    """

    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super(IntervalMLP, self).__init__()
        self.in_dim = in_channel * img_sz * img_sz
        self.fc1 = LinearInterval(self.in_dim, hidden_dim, input_layer=True)
        self.fc2 = LinearInterval(hidden_dim, hidden_dim)
        # Subject to be replaced dependent on task
        self.last = LinearInterval(hidden_dim, out_dim)
        # One trainable budget logit per eps-consuming layer (fc1, fc2, last).
        self.a = nn.Parameter(torch.zeros(3), requires_grad=True)
        self.e = None

        self.bounds = None

    def save_bounds(self, x):
        # Split off the lower/upper thirds of the interval activation.
        third = x.size(1) // 3
        self.bounds = (x[:, third:2 * third], x[:, 2 * third:])

    def calc_eps(self, r):
        # Softmax over `a` spreads the total budget r across the 3 layers.
        weights = self.a.exp()
        self.e = r * weights / weights.sum()

    def print_eps(self):
        """Log per-layer eps statistics (debugging aid)."""
        for layer in (self.fc1, self.fc2):
            stats = layer.eps.detach()
            print(f"sum: {stats.sum()} - mean: {stats.mean()} - std: {stats.std()}")

        for name, head in self.last.items():
            head_eps = head.eps.detach()
            print(f"last-{name} sum: {head_eps.sum()} - mean: {head_eps.mean()} - std: {head_eps.std()}")

    def reset_importance(self):
        # NOTE(review): `rest_importance` matches the spelling used by the
        # sibling model — presumably the project layer API; confirm.
        self.fc1.rest_importance()
        self.fc2.rest_importance()
        for _, head in self.last.items():
            head.rest_importance()

    def set_eps(self, eps, trainable=False):
        """Assign an eps budget to every layer.

        trainable=True splits `eps` via the learned softmax weights;
        otherwise every layer receives the same scalar `eps`.
        """
        if not trainable:
            self.fc1.calc_eps(eps)
            self.fc2.calc_eps(eps)
            for _, head in self.last.items():
                head.calc_eps(eps)
            return
        self.calc_eps(eps)
        self.fc1.calc_eps(self.e[0])
        self.fc2.calc_eps(self.e[1])
        for _, head in self.last.items():
            head.calc_eps(self.e[2])

    def features(self, x):
        flat = x.view(-1, self.in_dim)
        hidden = f.relu(self.fc1(flat))
        hidden = f.relu(self.fc2(hidden))
        self.save_bounds(hidden)
        return hidden

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        out = self.logits(self.features(x))
        # Keep only the middle (point-prediction) third of each head's output.
        return {name: v[:, :v.size(1) // 3] for name, v in out.items()}
# Example 3
class IntervalCNN(nn.Module):
    """Interval CNN: an input convolution, three conv stages (each ending in
    pooling + interval dropout), a linear feature layer, and a task head.

    Activations carry three stacked parts along dim 1: [mid | lower | upper].
    """

    def __init__(self, in_channel=3, out_dim=10, pooling=MaxPool2dInterval):
        super(IntervalCNN, self).__init__()

        self.input = Conv2dInterval(in_channel,
                                    32,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    input_layer=True)
        self.c1 = nn.Sequential(
            Conv2dInterval(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            pooling(2, stride=2, padding=0),
            IntervalDropout(0.25),
        )
        self.c2 = nn.Sequential(
            Conv2dInterval(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            pooling(2, stride=2, padding=0),
            IntervalDropout(0.25),
        )
        self.c3 = nn.Sequential(
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            Conv2dInterval(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # padding=1 here yields the 5x5 spatial size fc1 expects.
            pooling(2, stride=2, padding=1),
            IntervalDropout(0.25),
        )
        self.fc1 = nn.Sequential(LinearInterval(128 * 5 * 5, 256), nn.ReLU())
        self.last = LinearInterval(256, out_dim)
        # One trainable budget logit per eps-consuming layer
        # (input + 6 stage convs + fc1 + last).
        self.a = nn.Parameter(torch.zeros(9), requires_grad=True)
        self.e = None

        self.bounds = None

    def _eps_layers(self):
        # The eight eps-consuming layers, in budget order (head kept apart).
        return (self.input,
                self.c1[0], self.c1[2],
                self.c2[0], self.c2[2],
                self.c3[0], self.c3[2],
                self.fc1[0])

    def save_bounds(self, x):
        # Split off the lower/upper thirds of the interval activation.
        third = x.size(1) // 3
        self.bounds = (x[:, third:2 * third], x[:, 2 * third:])

    def print_eps(self):
        """Log per-layer eps statistics (debugging aid)."""
        for layer in self._eps_layers():
            stats = layer.eps.detach()
            print(f"sum: {stats.sum()} - mean: {stats.mean()} - std: {stats.std()}")
            print(f"min: {stats.min()} - max: {stats.max()}")

        for name, head in self.last.items():
            l = head.eps.detach()
            print(
                f"last-{name} sum: {l.sum()} - mean: {l.mean()} - std: {l.std()}"
            )
            print(f"min: {l.min()} - max: {l.max()}")

    def calc_eps(self, r):
        # Softmax over `a` spreads the total budget r across the 9 layers.
        weights = self.a.exp()
        self.e = r * weights / weights.sum()

    def reset_importance(self):
        # NOTE(review): `rest_importance` is presumably the project layer
        # API spelling (used consistently in this file) — confirm.
        for layer in self._eps_layers():
            layer.rest_importance()
        for _, head in self.last.items():
            head.rest_importance()

    def set_eps(self, eps, trainable=False):
        """Assign an eps budget to every layer.

        trainable=True splits `eps` via the learned softmax weights;
        otherwise every layer receives the same scalar `eps`.
        """
        if trainable:
            self.calc_eps(eps)
            for layer, budget in zip(self._eps_layers(), self.e):
                layer.calc_eps(budget)
            for _, head in self.last.items():
                head.calc_eps(self.e[8])
        else:
            for layer in self._eps_layers():
                layer.calc_eps(eps)
            for _, head in self.last.items():
                head.calc_eps(eps)

    def features(self, x):
        x = self.input(x)
        for stage in (self.c1, self.c2, self.c3):
            x = stage(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        self.save_bounds(x)
        return x

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        out = self.logits(self.features(x))
        # Keep only the middle (point-prediction) third of each head's output.
        return {name: v[:, :v.size(1) // 3] for name, v in out.items()}