Пример #1
0
    def forward(self,
                x,
                sample=None,
                label=None,
                device='cpu',
                return_rate=False):
        """Run the network, optionally collecting per-layer rate/channel terms.

        The feature extractor is ``self.layer`` (applied in order), followed by
        ``self.avgpool`` and ``self.classifier``.

        Args:
            x: Input batch fed to the first layer.
            sample: Passed to ``RD_fn`` as ``X`` and used as the first element
                of the chain handed to ``mi``. Only read when ``return_rate``
                is true.
            label: Passed to ``RD_fn`` as ``Label``.
            device: Forwarded to ``RD_fn`` and ``mi``.
            return_rate: When true, also return the stacked rate and channel
                terms.

        Returns:
            The classifier output when ``return_rate`` is false. Otherwise a
            tuple ``(logits, rates, channels)`` where ``rates`` stacks one
            ``RD_fn`` value per layer in ``self.layer`` and ``channels`` stacks
            ``mi(a, b, device, 'cov')`` for every consecutive pair in
            ``[sample, out_1, ..., out_L]`` (layer outputs detached).
        """
        rate = []
        temp = [sample]

        for layer in self.layer:
            x = layer(x)
            if return_rate:
                temp.append(x.detach())
                rate.append(RD_fn(T=x, X=sample, Label=label, device=device))

        # Shared head for both paths. torch.flatten (unlike .view) also works
        # when the pooled tensor is not contiguous, copying only if needed.
        x = self.classifier(torch.flatten(self.avgpool(x), 1))

        if not return_rate:
            return x

        # Channel terms between each pair of consecutive representations.
        channel = [mi(prev, nxt, device, 'cov') for prev, nxt in zip(temp, temp[1:])]
        return x, torch.stack(rate), torch.stack(channel)
Пример #2
0
    def forward(self,
                x,
                sample=None,
                label=None,
                device='cpu',
                return_rate=False):
        """Classify ``x``; optionally also return per-stage rate/channel terms.

        Pipeline: stem (conv1 -> bn1 -> relu -> maxpool), then layer1..layer4,
        then avgpool -> flatten -> fc.

        Args:
            x: Input batch.
            sample: Passed to ``RD_fn`` as ``X`` and placed first in the chain
                handed to ``mi``; only used when ``return_rate`` is true.
            label: Passed to ``RD_fn`` as ``Label``.
            device: Forwarded to ``RD_fn`` and ``mi``.
            return_rate: If true, return ``(logits, rates, channels)`` instead
                of only the logits.

        Returns:
            Logits from ``self.fc``. With ``return_rate`` the result is a
            tuple ``(logits, rates, channels)``:

            * ``rates`` stacks five ``RD_fn`` values: one for the stem output
              and one per residual stage.
            * ``channels`` stacks one ``mi(a, b, device, 'cov')`` per
              consecutive pair of ``[sample, stem, layer1, ..., layer4]``
              (representations detached).
        """
        h = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)

        if not return_rate:
            for stage in stages:
                h = stage(h)
            return self.fc(torch.flatten(self.avgpool(h), 1))

        chain = [sample, h.detach()]
        rates = [RD_fn(X=sample, T=h, Label=label, device=device)]
        for stage in stages:
            h = stage(h)
            chain.append(h.detach())
            rates.append(RD_fn(X=sample, T=h, Label=label, device=device))

        logits = self.fc(torch.flatten(self.avgpool(h), 1))
        channels = [mi(a, b, device, 'cov') for a, b in zip(chain[:-1], chain[1:])]
        return logits, torch.stack(rates), torch.stack(channels)
Пример #3
0
    def forward(self,
                x,
                sample=None,
                label=None,
                device='cpu',
                return_rate=False):
        """LeNet-style forward pass with optional per-step rate/channel terms.

        Steps:

        1. conv1 -> relu -> 2x2 max-pool
        2. conv2 -> relu -> 2x2 max-pool
        3. flatten -> fc1 -> relu
        4. fc2 -> relu
        5. fc3

        Args:
            x: Input batch.
            sample: Passed to ``RD_fn`` as ``X`` and placed first in the chain
                handed to ``mi``; only used when ``return_rate`` is true.
            label: Passed to ``RD_fn`` as ``Label``.
            device: Forwarded to ``RD_fn`` and ``mi``.
            return_rate: If true, return ``(output, rates, channels)``.

        Returns:
            The ``fc3`` output. With ``return_rate`` the result is a tuple
            ``(output, rates, channels)``:

            * ``rates`` stacks one ``RD_fn`` value per step (five in total).
            * ``channels`` stacks ``mi(a, b, device, 'cov')`` for consecutive
              pairs of ``[sample, step1, ..., step5]`` (detached).
        """
        def conv_block_1(t):
            return F.max_pool2d(F.relu(self.conv1(t)), (2, 2))

        def conv_block_2(t):
            return F.max_pool2d(F.relu(self.conv2(t)), (2, 2))

        def dense_1(t):
            # num_flat_features presumably returns the per-sample feature
            # count -- confirm against its definition.
            flat = t.view(-1, self.num_flat_features(t))
            return F.relu(self.fc1(flat))

        def dense_2(t):
            return F.relu(self.fc2(t))

        def output_layer(t):
            return self.fc3(t)

        pipeline = (conv_block_1, conv_block_2, dense_1, dense_2, output_layer)

        if not return_rate:
            for step in pipeline:
                x = step(x)
            return x

        reps = [sample]
        rates = []
        for step in pipeline:
            x = step(x)
            reps.append(x.detach())
            rates.append(RD_fn(X=sample, T=x, Label=label, device=device))

        channels = [mi(a, b, device, 'cov') for a, b in zip(reps, reps[1:])]
        return x, torch.stack(rates), torch.stack(channels)
Пример #4
0
    def forward(self, x, sample=None, label=None, device='cpu', return_rate=False):
        """Run layer1..layer5, then avgpool -> flatten -> classifier.

        Args:
            x: Input batch.
            sample: Passed to ``RD_fn`` as ``X`` and placed first in the chain
                handed to ``mi``; only used when ``return_rate`` is true.
            label: Passed to ``RD_fn`` as ``Label``.
            device: Forwarded to ``RD_fn`` and ``mi``.
            return_rate: If true, return ``(logits, rates, channels)``.

        Returns:
            Classifier output. With ``return_rate`` the result is a tuple
            ``(logits, rates, channels)``:

            * ``rates`` stacks one ``RD_fn`` value per stage (five in total).
            * ``channels`` stacks ``mi(a, b, device, 'cov')`` for consecutive
              pairs of ``[sample, layer1, ..., layer5]`` (detached).
        """
        stages = [self.layer1, self.layer2, self.layer3, self.layer4, self.layer5]

        def head(features):
            pooled = self.avgpool(features)
            return self.classifier(torch.flatten(pooled, 1))

        if return_rate:
            history, rates = [sample], []
            for stage in stages:
                x = stage(x)
                history.append(x.detach())
                rates.append(RD_fn(X=sample, T=x, Label=label, device=device))
            x = head(x)
            channels = [
                mi(earlier, later, device, 'cov')
                for earlier, later in zip(history[:-1], history[1:])
            ]
            return x, torch.stack(rates), torch.stack(channels)

        for stage in stages:
            x = stage(x)
        return head(x)
Пример #5
0
    def forward(self,
                x,
                sample=None,
                label=None,
                device='cpu',
                return_rate=False):
        """Fully connected forward pass with optional per-layer rate/channel terms.

        The input is reshaped into rows of 784 features, passed through
        l1..l4 (each followed by ``self.gate``) and finally through ``l5``
        with no gate.

        Args:
            x: Input batch; every 784 consecutive elements form one row.
            sample: Passed to ``RD_fn`` as ``X`` and placed first in the chain
                handed to ``mi``; only used when ``return_rate`` is true.
            label: Passed to ``RD_fn`` as ``Label``.
            device: Forwarded to ``RD_fn`` and ``mi``.
            return_rate: If true, return ``(output, rates, channels)``.

        Returns:
            The ``l5`` output. With ``return_rate`` the result is a tuple
            ``(output, rates, channels)``:

            * ``rates`` stacks one ``RD_fn`` value per layer (five in total).
            * ``channels`` stacks ``mi(a, b, device, 'cov')`` for consecutive
              pairs of ``[sample, h1, ..., h5]`` (detached).
        """
        # Flatten to (rows, 784); 784 = 28 * 28. NOTE(review): presumably
        # 1x28x28 inputs -- view(-1, 784) folds any extra per-sample elements
        # into the batch dimension instead of failing, so confirm callers.
        x = x.view(-1, 784)

        rate = []
        temp = [sample]

        for layer in (self.l1, self.l2, self.l3, self.l4):
            x = self.gate(layer(x))
            if return_rate:
                temp.append(x.detach())
                rate.append(RD_fn(T=x, X=sample, Label=label, device=device))

        # Output layer: no gate applied.
        x = self.l5(x)
        if not return_rate:
            return x

        temp.append(x.detach())
        rate.append(RD_fn(T=x, X=sample, Label=label, device=device))

        channel = [mi(prev, nxt, device, 'cov') for prev, nxt in zip(temp, temp[1:])]
        return x, torch.stack(rate), torch.stack(channel)
Пример #6
0
def mi_from_experiments(experiments):
    """Show a heat map of pairwise ``mi`` values between experiment columns.

    For each row of ``experiments``, the items yielded by ``concat(row)`` are
    reduced to their element at index 1. ``transpose`` is then applied to
    those per-row lists (presumably producing columns -- confirm against its
    definition). Cell ``[i][j]`` of the plotted matrix is
    ``mi(cols[j], cols[i], correct=False)``.

    Args:
        experiments: Iterable of rows accepted by ``concat``; every item that
            ``concat(row)`` yields must support ``item[1]``.
    """
    # Build real lists instead of ``map`` objects: on Python 3 ``map`` returns
    # a one-shot iterator rather than a sequence.
    rows = [[item[1] for item in concat(row)] for row in experiments]
    # Materialise the columns once: the nested comprehension below traverses
    # them repeatedly, which would silently truncate the matrix if
    # ``transpose`` returned a one-shot iterator such as ``zip``.
    cols = list(transpose(rows))
    matrix = [[mi(col1, col2, correct=False) for col1 in cols] for col2 in cols]
    plt.imshow(matrix, interpolation='none')
    plt.colorbar()
    plt.show()