Code example #1
    def __init__(self,
                 num_classes=10,
                 in_features=256,
                 weights_init_path=None):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.in_features = in_features
        self.num_channels = 3
        self.image_size = 28
        self.name = 'AlexNet'
        self.dropout_keep_prob = 0.5

        self.setup_net()
        # if centroid:
        #     self.centroid = nn.Embedding(num_cls, self.out_dim)
        # Randomly initialize all layers; when a checkpoint path is given,
        # additionally overwrite matching parameters with the pretrained ones.
        if weights_init_path is not None:
            init_weights(self)
            self.load(weights_init_path)
        else:
            init_weights(self)

        # Re-initialize fc9 (used when a bottleneck width in_features is
        # configured) or fc8 with a small clamped Gaussian and a constant bias.
        if self.in_features != 0:
            self.fc9.weight.data.normal_(0, 0.005).clamp_(min=-0.01, max=0.01)
            self.fc9.bias.data.fill_(0.1)
        else:
            self.fc8.weight.data.normal_(0, 0.005).clamp_(min=-0.01, max=0.01)
            self.fc8.bias.data.fill_(0.1)
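
A minimal usage sketch, assuming AlexNet (together with its setup_net, load, and init_weights helpers) is importable from the surrounding repository; the argument values are illustrative only:

net_bottleneck = AlexNet(num_classes=10, in_features=256)  # fc9 gets the clamped Gaussian re-init
net_plain = AlexNet(num_classes=10, in_features=0)         # fc8 gets the clamped Gaussian re-init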
Code example #2
    def __init__(self, in_features):
        super(DigitDiscriminator, self).__init__()
        self.in_features = in_features

        self.discriminator = nn.Sequential(
            nn.Linear(self.in_features, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 1)
        )

        init_weights(self)
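
A hedged usage sketch, assuming DigitDiscriminator is importable from the repository; it calls the Sequential stack directly, since the class's forward method is not shown above, and the feature and batch sizes are illustrative assumptions:

import torch

disc = DigitDiscriminator(in_features=500)   # assumed feature dimensionality
features = torch.randn(32, 500)              # dummy batch of feature vectors
logits = disc.discriminator(features)        # shape (32, 1): one real-valued score per sample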
Code example #3
    def __init__(self, num_classes=10, weights_init_path=None, in_features=0):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.in_features = in_features
        self.num_channels = 1
        self.image_size = 28
        self.name = 'LeNet'
        self.setup_net()

        if weights_init_path is not None:
            init_weights(self)
            self.load(weights_init_path)
        else:
            init_weights(self)
Code example #4
    def __init__(self, num_classes=10, weights_init_path=None, in_features=0, num_domains=2, mmd=False):
        super(DSBNLeNet, self).__init__()
        self.num_classes = num_classes
        self.in_features = in_features
        self.num_channels = 1
        self.image_size = 28
        self.num_domains = num_domains
        self.mmd = mmd
        self.name = 'DSBNLeNet'
        self.setup_net()

        if weights_init_path is not None:
            init_weights(self)
            self.load(weights_init_path)
        else:
            init_weights(self)
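
A hedged instantiation sketch, assuming DSBNLeNet is importable from the repository; the remark about per-domain batch normalization follows from the class name (domain-specific batch norm) rather than from code shown here:

# Hypothetical usage: a LeNet variant that presumably keeps separate batch-norm
# statistics per domain, here configured for a source/target pair.
net = DSBNLeNet(num_classes=10, num_domains=2, mmd=False)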
Code example #5
    def load(self, init_path):
        net_init_dict = torch.load(init_path)
        init_weights(self)
        updated_state_dict = self.state_dict()
        print('load {} params.'.format(init_path))
        # Copy only the pretrained tensors whose names and shapes match the
        # current model; everything else keeps its fresh initialization.
        for k, v in updated_state_dict.items():
            if k in net_init_dict:
                if v.shape == net_init_dict[k].shape:
                    updated_state_dict[k] = net_init_dict[k]
                else:
                    print(
                        "{0} params' shape not the same as pretrained params. Initialize with default settings.".format(
                            k))
            else:
                print("{0} params does not exist. Initialize with default settings.".format(k))
        self.load_state_dict(updated_state_dict)
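
The same partial-loading pattern can be written as a standalone helper. This is a sketch under the assumption that only name- and shape-matching checkpoint entries should overwrite the freshly initialized parameters; the function name load_matching_params is hypothetical:

import torch

def load_matching_params(model, init_path):
    # Hypothetical helper mirroring the load() method above: copy only the
    # checkpoint tensors whose names and shapes match the current model.
    checkpoint = torch.load(init_path)
    state = model.state_dict()
    for name, param in state.items():
        if name in checkpoint and checkpoint[name].shape == param.shape:
            state[name] = checkpoint[name]
    model.load_state_dict(state)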
Code example #6
    def __init__(self, in_features):
        super(CPUADiscriminator, self).__init__()
        self.in_features = in_features

        self.discriminator = nn.Sequential(
            nn.Linear(self.in_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1)
        )

        init_weights(self)
        # discriminator[4] is the final nn.Linear(1024, 1): give it a wider
        # clamped Gaussian initialization and a zero bias.
        last_layer = self.discriminator[4]
        last_layer.weight.data.normal_(0, 0.3).clamp_(min=-0.6, max=0.6)
        last_layer.bias.data.zero_()
Code example #7
    def __init__(self, num_classes, in_features):
        super(ProjectionDiscriminator, self).__init__()
        self.num_classes = num_classes
        self.in_features = in_features

        self.fc1 = nn.Linear(in_features, in_features)
        self.fc2 = nn.Linear(in_features, in_features)
        self.embed = nn.Embedding(num_classes, in_features)
        self.linear = nn.Linear(in_features, 1)
        # self.fc1 = spectral_norm(nn.Linear(in_features, in_features))
        # self.fc2 = spectral_norm(nn.Linear(in_features, in_features))
        # self.embed = spectral_norm(nn.Embedding(num_classes, in_features))
        # self.linear = spectral_norm(nn.Linear(in_features, 1))

        init_weights(self)
        # Widen the initialization of the final scoring layer and zero its bias,
        # mirroring the discriminator in code example #6.
        last_layer = self.linear
        last_layer.weight.data.normal_(0, 0.3).clamp_(min=-0.6, max=0.6)
        last_layer.bias.data.zero_()
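
The constructor above only defines the layers; a typical projection-discriminator forward pass (in the style of Miyato and Koyama's projection discriminator) adds an inner product between the features and a class embedding to an unconditional score. A minimal sketch under that assumption, not the repository's actual forward method:

import torch
import torch.nn.functional as F

def projection_forward(disc, x, y):
    # Hypothetical forward pass for ProjectionDiscriminator:
    # two ReLU fully connected layers, an unconditional score from disc.linear,
    # and a class-conditional projection term via disc.embed.
    h = F.relu(disc.fc1(x))
    h = F.relu(disc.fc2(h))
    out = disc.linear(h)                                            # shape (N, 1)
    out = out + torch.sum(disc.embed(y) * h, dim=1, keepdim=True)   # projection term
    return out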