def __init__(self, input_shape, num_classes, num_domains, hparams,
             conditional, class_balance):
    super(AbstractDANN, self).__init__(input_shape, num_classes,
                                       num_domains, hparams)

    self.register_buffer('update_count', torch.tensor([0]))
    self.conditional = conditional
    self.class_balance = class_balance

    # Algorithms
    self.featurizer = networks.Featurizer(input_shape, self.hparams)
    self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
    self.discriminator = networks.MLP(self.featurizer.n_outputs,
                                      num_domains, self.hparams)
    self.class_embeddings = nn.Embedding(num_classes,
                                         self.featurizer.n_outputs)

    # Optimizers
    self.disc_opt = torch.optim.Adam(
        (list(self.discriminator.parameters()) +
         list(self.class_embeddings.parameters())),
        lr=self.hparams["lr_d"],
        weight_decay=self.hparams['weight_decay_d'],
        betas=(self.hparams['beta1'], 0.9))

    self.gen_opt = torch.optim.Adam(
        (list(self.featurizer.parameters()) +
         list(self.classifier.parameters())),
        lr=self.hparams["lr_g"],
        weight_decay=self.hparams['weight_decay_g'],
        betas=(self.hparams['beta1'], 0.9))
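# A condensed, hedged sketch (not the verbatim DomainBed update) of how the
# two optimizers above interact: the discriminator learns to predict the
# domain from features, while the featurizer/classifier pair fits the task
# while fooling the discriminator. Assumes the module-level imports of
# algorithms.py (torch, torch.nn.functional as F) and that `minibatches` is
# a list of (x, y) pairs, one per training domain. The adversarial weighting
# hyperparameter and the `class_balance` reweighting of the discriminator
# loss are omitted for brevity.
def update_sketch(self, minibatches):
    all_x = torch.cat([x for x, y in minibatches])
    all_y = torch.cat([y for x, y in minibatches])
    all_z = self.featurizer(all_x)

    # Conditional DANN feeds the discriminator label-conditioned features.
    if self.conditional:
        disc_input = all_z + self.class_embeddings(all_y)
    else:
        disc_input = all_z
    disc_out = self.discriminator(disc_input)

    # Domain labels: minibatch i comes entirely from domain i.
    disc_labels = torch.cat([
        torch.full((x.shape[0],), i, dtype=torch.int64, device=x.device)
        for i, (x, y) in enumerate(minibatches)])
    disc_loss = F.cross_entropy(disc_out, disc_labels)

    self.update_count += 1
    if self.update_count.item() % 2 == 1:
        # Discriminator step: learn to predict the domain.
        self.disc_opt.zero_grad()
        disc_loss.backward()
        self.disc_opt.step()
        return {'disc_loss': disc_loss.item()}
    else:
        # Generator step: minimize task loss, maximize discriminator loss.
        classifier_loss = F.cross_entropy(self.classifier(all_z), all_y)
        gen_loss = classifier_loss - disc_loss
        self.gen_opt.zero_grad()
        gen_loss.backward()
        self.gen_opt.step()
        return {'gen_loss': gen_loss.item()}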
def __init__(self, input_shape, num_classes, num_domains, hparams):
    super(ERM, self).__init__(input_shape, num_classes,
                              num_domains, hparams)
    self.featurizer = networks.Featurizer(input_shape, self.hparams)
    self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
    self.network = nn.Sequential(self.featurizer, self.classifier)
    self.optimizer = torch.optim.Adam(
        self.network.parameters(),
        lr=self.hparams["lr"],
        weight_decay=self.hparams['weight_decay'])
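# A minimal sketch of the matching ERM update step, assuming `minibatches`
# is a list of (x, y) tensor pairs (one per training domain) and that
# torch.nn.functional is imported as F. ERM simply pools all domains into
# one batch and minimizes cross-entropy with a single optimizer.
def update_sketch(self, minibatches):
    all_x = torch.cat([x for x, y in minibatches])
    all_y = torch.cat([y for x, y in minibatches])
    loss = F.cross_entropy(self.network(all_x), all_y)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    return {'loss': loss.item()}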
def test_featurizer(self, dataset_name):
    """Test that Featurizer() returns a module which can take a
    correctly-sized input and return a correctly-sized output."""
    batch_size = 8
    hparams = hparams_registry.default_hparams('ERM', dataset_name)
    dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
    input_ = helpers.make_minibatches(dataset, batch_size)[0][0]
    input_shape = dataset.input_shape
    algorithm = networks.Featurizer(input_shape, hparams).cuda()
    output = algorithm(input_)
    self.assertEqual(list(output.shape), [batch_size, algorithm.n_outputs])
def __init__(self, input_shape, num_classes, num_domains, hparams):
    super(SagNet, self).__init__(input_shape, num_classes,
                                 num_domains, hparams)
    # featurizer network
    self.network_f = networks.Featurizer(input_shape, self.hparams)
    # content network
    self.network_c = networks.Classifier(
        self.network_f.n_outputs,
        num_classes,
        self.hparams['nonlinear_classifier'])
    # style network
    self.network_s = networks.Classifier(
        self.network_f.n_outputs,
        num_classes,
        self.hparams['nonlinear_classifier'])

    # # This commented block of code implements something closer to the
    # # original paper, but is specific to ResNet and puts the other
    # # algorithms at a disadvantage.
    #
    # resnet_c = networks.Featurizer(input_shape, self.hparams)
    # resnet_s = networks.Featurizer(input_shape, self.hparams)
    # # featurizer network
    # self.network_f = torch.nn.Sequential(
    #     resnet_c.network.conv1,
    #     resnet_c.network.bn1,
    #     resnet_c.network.relu,
    #     resnet_c.network.maxpool,
    #     resnet_c.network.layer1,
    #     resnet_c.network.layer2,
    #     resnet_c.network.layer3)
    # # content network
    # self.network_c = torch.nn.Sequential(
    #     resnet_c.network.layer4,
    #     resnet_c.network.avgpool,
    #     networks.Flatten(),
    #     resnet_c.network.fc)
    # # style network
    # self.network_s = torch.nn.Sequential(
    #     resnet_s.network.layer4,
    #     resnet_s.network.avgpool,
    #     networks.Flatten(),
    #     resnet_s.network.fc)

    def opt(p):
        return torch.optim.Adam(p, lr=hparams["lr"],
                                weight_decay=hparams["weight_decay"])

    self.optimizer_f = opt(self.network_f.parameters())
    self.optimizer_c = opt(self.network_c.parameters())
    self.optimizer_s = opt(self.network_s.parameters())
    self.weight_adv = hparams["sag_w_adv"]
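# A hedged, illustrative sketch of the style-randomization idea behind
# SagNet (not the verbatim DomainBed helper, whose exact shapes and
# signature differ): instance-wise feature statistics are interpolated with
# those of a randomly permuted batch, so the content branch must learn to
# ignore style, while the style branch is trained adversarially with weight
# `sag_w_adv`. Assumes `x` has shape (batch, features).
def randomize_style_sketch(x, eps=1e-5):
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, keepdim=True)
    x_norm = (x - mean) / (var + eps).sqrt()  # style-free "content"

    # Mix each example's statistics with those of a random other example.
    idx = torch.randperm(x.size(0), device=x.device)
    alpha = torch.rand(x.size(0), 1, device=x.device)
    mean = alpha * mean + (1 - alpha) * mean[idx]
    var = alpha * var + (1 - alpha) * var[idx]

    # Re-apply the randomized style to the normalized content.
    return x_norm * (var + eps).sqrt() + mean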
def __init__(self, input_shape, num_classes, num_domains, hparams):
    super(MTL, self).__init__(input_shape, num_classes,
                              num_domains, hparams)
    self.featurizer = networks.Featurizer(input_shape, self.hparams)
    # The classifier sees per-example features concatenated with a
    # per-domain mean embedding, hence the doubled input width.
    self.classifier = nn.Linear(self.featurizer.n_outputs * 2, num_classes)
    self.optimizer = torch.optim.Adam(
        list(self.featurizer.parameters()) +
        list(self.classifier.parameters()),
        lr=self.hparams["lr"],
        weight_decay=self.hparams['weight_decay'])

    self.register_buffer(
        'embeddings',
        torch.zeros(num_domains, self.featurizer.n_outputs))

    self.ema = self.hparams['mtl_ema']
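# A minimal sketch of how the `embeddings` buffer above is typically
# maintained (hedged; the method name and details here are illustrative):
# each domain keeps an exponential moving average of its mean feature
# vector, which is then tiled and concatenated with per-example features
# before classification.
def update_embeddings_sketch(self, features, env):
    domain_embedding = features.mean(0)
    if env is not None:
        # EMA update of the stored per-domain embedding.
        domain_embedding = (self.ema * self.embeddings[env] +
                            (1 - self.ema) * domain_embedding)
        self.embeddings[env] = domain_embedding.clone().detach()
    # One copy of the domain embedding per example in the batch.
    return domain_embedding.view(1, -1).repeat(len(features), 1)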