def __init__(self, num_classes=10, in_features=256, weights_init_path=None):
    """Construct the AlexNet model and initialize its parameters.

    Args:
        num_classes: Stored on ``self.num_classes``; presumably the width of
            the classifier output created by ``setup_net`` (TODO confirm).
        in_features: Stored on ``self.in_features``. When non-zero, ``fc9``
            receives the custom output-layer initialization below; when zero,
            ``fc8`` does.
        weights_init_path: Optional checkpoint path. When given, it is passed
            to ``self.load`` after ``init_weights`` has run.
    """
    super(AlexNet, self).__init__()
    # Plain configuration attributes; assigned before setup_net(), which may read them.
    self.name = 'AlexNet'
    self.num_classes = num_classes
    self.in_features = in_features
    self.num_channels = 3
    # NOTE(review): 28 matches the LeNet models in this file; confirm it is intended for AlexNet.
    self.image_size = 28
    # NOTE(review): whether setup_net() treats this as a keep or a drop probability is not visible here.
    self.dropout_keep_prob = 0.5
    self.setup_net()

    # Both the pretrained and the from-scratch paths start from init_weights().
    init_weights(self)
    if weights_init_path is not None:
        self.load(weights_init_path)

    # Custom init for the output layer: fc9 when in_features is non-zero, else fc8.
    # This runs after self.load(), so checkpoint values for that layer are overwritten.
    head = self.fc9 if self.in_features != 0 else self.fc8
    head.weight.data.normal_(0, 0.005).clamp_(min=-0.01, max=0.01)
    head.bias.data.fill_(0.1)
def __init__(self, in_features):
    """Build a three-layer MLP discriminator with 500 hidden units and one output.

    Args:
        in_features: Input width of the first ``nn.Linear`` layer; also stored
            on ``self.in_features``.
    """
    super(DigitDiscriminator, self).__init__()
    self.in_features = in_features
    hidden = 500
    layers = [
        nn.Linear(in_features, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 1),
    ]
    # Unpacking the list keeps the same Sequential indices (0..4) as before.
    self.discriminator = nn.Sequential(*layers)
    init_weights(self)
def __init__(self, num_classes=10, weights_init_path=None, in_features=0):
    """Construct the LeNet model and initialize its parameters.

    Args:
        num_classes: Stored on ``self.num_classes``; presumably consumed by
            ``setup_net`` to size the classifier (TODO confirm).
        weights_init_path: Optional checkpoint path. When given, it is passed
            to ``self.load`` after ``init_weights`` has run.
        in_features: Stored on ``self.in_features``; its use is inside
            ``setup_net``, which is not shown here.
    """
    super(LeNet, self).__init__()
    # Configuration attributes; assigned before setup_net(), which may read them.
    self.name = 'LeNet'
    self.num_classes = num_classes
    self.in_features = in_features
    self.num_channels = 1
    self.image_size = 28
    self.setup_net()

    # init_weights() runs on both paths; a checkpoint, if any, is loaded afterwards.
    init_weights(self)
    if weights_init_path is not None:
        self.load(weights_init_path)
def __init__(self, num_classes=10, weights_init_path=None, in_features=0, num_domains=2, mmd=False):
    """Construct the DSBNLeNet model and initialize its parameters.

    Args:
        num_classes: Stored on ``self.num_classes``; presumably consumed by
            ``setup_net`` to size the classifier (TODO confirm).
        weights_init_path: Optional checkpoint path. When given, it is passed
            to ``self.load`` after ``init_weights`` has run.
        in_features: Stored on ``self.in_features``.
        num_domains: Stored on ``self.num_domains``; presumably the number of
            domain-specific branches built by ``setup_net`` (TODO confirm).
        mmd: Stored on ``self.mmd``; its effect is not visible in this method.
    """
    super(DSBNLeNet, self).__init__()
    # Configuration attributes; assigned before setup_net(), which may read them.
    self.name = 'DSBNLeNet'
    self.num_classes = num_classes
    self.in_features = in_features
    self.num_channels = 1
    self.image_size = 28
    self.num_domains = num_domains
    self.mmd = mmd
    self.setup_net()

    # init_weights() runs on both paths; a checkpoint, if any, is loaded afterwards.
    init_weights(self)
    if weights_init_path is not None:
        self.load(weights_init_path)
def load(self, init_path):
    """Load the matching parameters from a checkpoint into this model.

    The model is first re-initialized with ``init_weights``. Each entry of
    ``self.state_dict()`` whose key is present in the checkpoint with the same
    shape is then replaced by the checkpoint tensor. Entries that are missing
    from the checkpoint, or whose shape differs, keep their fresh
    initialization, and a message is printed for each one.

    Args:
        init_path: Path or file-like object accepted by ``torch.load``. The
            file is assumed to hold a flat state dict mapping parameter names
            to tensors (TODO confirm against the checkpoints in use).
    """
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # hosts; load_state_dict() copies the values into the model's own tensors,
    # so they end up on whatever device the model lives on.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    net_init_dict = torch.load(init_path, map_location='cpu')
    init_weights(self)
    updated_state_dict = self.state_dict()
    print('load {} params.'.format(init_path))
    for k, v in updated_state_dict.items():
        if k not in net_init_dict:
            print("{0} params does not exist. Initialize with default settings.".format(k))
            continue
        if v.shape != net_init_dict[k].shape:
            print(
                "{0} params' shape not the same as pretrained params. Initialize with default settings.".format(
                    k))
            continue
        # Reassigning the value of an existing key is safe while iterating.
        updated_state_dict[k] = net_init_dict[k]
    self.load_state_dict(updated_state_dict)
def __init__(self, in_features):
    """Build a three-layer MLP discriminator with 1024 hidden units and one output.

    After the generic ``init_weights`` pass, the final ``nn.Linear`` layer is
    re-initialized: weights are drawn from N(0, 0.3) and clipped to
    [-0.6, 0.6], and biases are set to zero.

    Args:
        in_features: Input width of the first ``nn.Linear`` layer; also stored
            on ``self.in_features``.
    """
    super(CPUADiscriminator, self).__init__()
    self.in_features = in_features
    width = 1024
    self.discriminator = nn.Sequential(
        *[
            nn.Linear(in_features, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, 1),
        ]
    )
    init_weights(self)
    # The last Sequential entry (index 4) is the output Linear layer.
    output_layer = self.discriminator[-1]
    # Clip at two standard deviations of the N(0, 0.3) draw.
    output_layer.weight.data.normal_(0, 0.3).clamp_(min=-0.6, max=0.6)
    output_layer.bias.data.zero_()
def __init__(self, num_classes, in_features):
    """Create the layers of a discriminator that has a per-class embedding table.

    The forward pass is defined elsewhere. This method registers two square
    ``nn.Linear`` layers, an ``nn.Embedding`` with one ``in_features``-wide
    row per class, and a scalar output ``nn.Linear``. It then applies
    ``init_weights`` and overrides the output layer's initialization.

    Args:
        num_classes: Number of rows in the embedding table; also stored on
            ``self.num_classes``.
        in_features: Feature width used by every layer; also stored on
            ``self.in_features``.
    """
    super(ProjectionDiscriminator, self).__init__()
    self.num_classes = num_classes
    self.in_features = in_features
    width = in_features
    # Registration order (fc1, fc2, embed, linear) is kept as-is: it fixes the
    # state_dict key order and the order in which init_weights() visits modules.
    self.fc1 = nn.Linear(width, width)
    self.fc2 = nn.Linear(width, width)
    self.embed = nn.Embedding(num_classes, width)
    self.linear = nn.Linear(width, 1)
    # NOTE: an earlier variant wrapped these four layers in spectral_norm(...).
    init_weights(self)
    # Output layer: N(0, 0.3) weights clipped to [-0.6, 0.6], zero bias.
    self.linear.weight.data.normal_(0, 0.3).clamp_(min=-0.6, max=0.6)
    self.linear.bias.data.zero_()