コード例 #1
0
    def __init__(self, loss, resnet_type="32", num_classes=10):
        """Initialise the learner with a ResNet backbone as feature extractor.

        Args:
            loss: Forwarded unchanged to the parent constructor.
            resnet_type: Identifier handed to ``get_resnet`` and to the
                parent constructor.
            num_classes: Forwarded to the parent constructor.
        """
        super(SimpleLearner, self).__init__(
            loss, resnet_type=resnet_type, num_classes=num_classes)

        # Replace the backbone's ``fc`` attribute with an empty
        # ``nn.Sequential`` before attaching the backbone to this instance.
        backbone = get_resnet(resnet_type)
        backbone.fc = nn.Sequential()
        self.features_extractor = backbone
コード例 #2
0
    def __init__(self, loss, resnet_type="32", num_classes=10):
        """Initialise the LwF learner with a ResNet backbone.

        Args:
            loss: Forwarded unchanged to the parent constructor.
            resnet_type: Identifier handed to ``get_resnet`` and to the
                parent constructor.
            num_classes: Forwarded to the parent constructor.
        """
        super(LwF, self).__init__(
            loss, resnet_type=resnet_type, num_classes=num_classes)

        # Replace the backbone's ``fc`` attribute with an empty
        # ``nn.Sequential`` before attaching the backbone to this instance.
        backbone = get_resnet(resnet_type)
        backbone.fc = nn.Sequential()
        self.features_extractor = backbone

        # No previous model is recorded at construction time.
        self.previous_model = None
コード例 #3
0
    def init_erm_phase(self):
        """Load the saved MatchDG contrastive-phase model and build matched pairs.

        Constructs the network selected by ``self.args.ctr_model_name``,
        restores its weights from the ``matchdg_ctr`` results directory,
        switches it to eval mode and hands it to ``get_matched_pairs``.

        Returns:
            tuple: The first two values returned by ``get_matched_pairs``
            (``data_match_tensor``, ``label_match_tensor``).

        Raises:
            ValueError: If ``self.args.ctr_model_name`` does not select any
                supported architecture.
            RuntimeError: If ``self.args.os_env`` is set but the
                ``PT_DATA_DIR`` environment variable is not defined.
        """
        ctr_model_name = self.args.ctr_model_name
        # The 'fc', resnet and densenet builders are all called with fc_layer=0 here.
        fc_layer = 0

        if ctr_model_name == 'lenet':
            from models.lenet import LeNet5
            ctr_phi = LeNet5().to(self.cuda)
        elif ctr_model_name == 'alexnet':
            from models.alexnet import alexnet
            ctr_phi = alexnet(self.args.out_classes, self.args.pre_trained,
                              'matchdg_ctr').to(self.cuda)
        elif ctr_model_name == 'fc':
            from models.fc import FC
            ctr_phi = FC(self.args.out_classes, fc_layer).to(self.cuda)
        elif 'resnet' in ctr_model_name:
            from models.resnet import get_resnet
            ctr_phi = get_resnet(ctr_model_name, self.args.out_classes,
                                 fc_layer, self.args.img_c,
                                 self.args.pre_trained,
                                 self.args.os_env).to(self.cuda)
        elif 'densenet' in ctr_model_name:
            from models.densenet import get_densenet
            ctr_phi = get_densenet(ctr_model_name, self.args.out_classes,
                                   fc_layer, self.args.img_c,
                                   self.args.pre_trained,
                                   self.args.os_env).to(self.cuda)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``ctr_phi`` further down.
            raise ValueError(
                'Unsupported ctr_model_name: {!r}'.format(ctr_model_name))

        # Load MatchDG CTR phase model from the saved weights
        if self.args.os_env:
            root_dir = os.getenv('PT_DATA_DIR')
            if root_dir is None:
                raise RuntimeError(
                    'args.os_env is set but the PT_DATA_DIR environment '
                    'variable is not defined')
        else:
            root_dir = "results"
        base_res_dir = '/'.join([
            root_dir,
            self.args.dataset_name,
            'matchdg_ctr',
            self.args.ctr_match_layer,
            'train_' + str(self.args.train_domains),
        ])
        save_path = base_res_dir + '/Model_' + self.ctr_load_post_string + '.pth'
        ctr_phi.load_state_dict(torch.load(save_path))
        ctr_phi.eval()

        # match_case == -1 selects the inferred-match case; any other value
        # is the "x% percentage match" initial strategy. Only the
        # ``inferred_match`` flag differs between the two calls.
        inferred_match = 1 if self.args.match_case == -1 else 0
        data_match_tensor, label_match_tensor, _, _ = get_matched_pairs(
            self.args, self.cuda, self.train_dataset, self.domain_size,
            self.total_domains, self.training_list_size, ctr_phi,
            self.args.match_case, self.args.perfect_match, inferred_match)

        return data_match_tensor, label_match_tensor
コード例 #4
0
ファイル: base_eval.py プロジェクト: nijinjose/robustdg
    def get_model(self, run_matchdg_erm=0):
        """Build the network named by ``self.args.model_name`` and load its weights.

        The model is moved to ``self.cuda``, stored on ``self.phi``, and then
        ``self.load_model(run_matchdg_erm)`` is called.

        Args:
            run_matchdg_erm: Forwarded unchanged to ``self.load_model``.

        Raises:
            ValueError: If ``self.args.model_name`` does not select any
                supported architecture.
        """
        model_name = self.args.model_name

        def select_fc_layer():
            # Methods 'csd' and 'matchdg_ctr' always build with fc_layer=0;
            # every other method uses the configured value. Evaluated lazily
            # so architectures that take no fc_layer never read these args.
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                return 0
            return self.args.fc_layer

        if model_name == 'lenet':
            from models.lenet import LeNet5
            phi = LeNet5()
        elif model_name == 'fc':
            from models.fc import FC
            phi = FC(self.args.out_classes, select_fc_layer())
        elif model_name == 'domain_bed_mnist':
            from models.domain_bed_mnist import DomainBed
            phi = DomainBed(self.args.img_c)
        elif model_name == 'alexnet':
            from models.alexnet import alexnet
            phi = alexnet(model_name, self.args.out_classes,
                          select_fc_layer(), self.args.img_c,
                          self.args.pre_trained, self.args.os_env)
        elif 'resnet' in model_name:
            from models.resnet import get_resnet
            phi = get_resnet(model_name, self.args.out_classes,
                             select_fc_layer(), self.args.img_c,
                             self.args.pre_trained, self.args.os_env)
        elif 'densenet' in model_name:
            from models.densenet import get_densenet
            phi = get_densenet(model_name, self.args.out_classes,
                               select_fc_layer(), self.args.img_c,
                               self.args.pre_trained, self.args.os_env)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``phi`` below.
            raise ValueError(
                'Unsupported model_name: {!r}'.format(model_name))

        print('Model Architecture: ', self.args.model_name)

        self.phi = phi.to(self.cuda)
        self.load_model(run_matchdg_erm)

        return
コード例 #5
0
    def __init__(self, loss, resnet_type="32", num_classes=10):
        """Initialise the multi-task learner: backbone plus linear classifier.

        Args:
            loss: Stored on ``self.loss``.
            resnet_type: Identifier handed to ``get_resnet``.
            num_classes: Stored on ``self.n_classes`` and used as the output
                size of the linear classifier.
        """
        super(MultiTaskLearner, self).__init__()
        self.n_known = 0
        self.n_classes = num_classes

        self.loss = loss

        # Replace the backbone's ``fc`` attribute with an empty
        # ``nn.Sequential``, then attach the backbone to this instance.
        backbone = get_resnet(resnet_type)
        backbone.fc = nn.Sequential()
        self.features_extractor = backbone

        # The classifier input width is the backbone's ``out_dim``.
        self.classifier = nn.Linear(backbone.out_dim, num_classes)
コード例 #6
0
ファイル: base_eval.py プロジェクト: frezaeix/robustdg
 def get_model(self):
     """Build the network named by ``self.args.model_name`` on ``self.cuda``.

     Returns:
         The constructed model after ``.to(self.cuda)``.

     Raises:
         ValueError: If ``self.args.model_name`` is not one of 'lenet',
             'alexnet' or 'resnet18'.
     """
     model_name = self.args.model_name

     if model_name == 'lenet':
         from models.lenet import LeNet5
         phi = LeNet5()
     elif model_name == 'alexnet':
         from models.alexnet import alexnet
         phi = alexnet(self.args.out_classes, self.args.pre_trained,
                       self.args.method_name)
     elif model_name == 'resnet18':
         from models.resnet import get_resnet
         phi = get_resnet('resnet18', self.args.out_classes,
                          self.args.method_name, self.args.img_c,
                          self.args.pre_trained)
     else:
         # Fail with a clear message instead of an UnboundLocalError on
         # ``phi`` below.
         raise ValueError(
             'Unsupported model_name: {!r}'.format(model_name))

     print('Model Architecture: ', self.args.model_name)
     phi = phi.to(self.cuda)
     return phi
コード例 #7
0
    def get_model(self, run_matchdg_erm=0):
        """Build the network named by ``self.args.model_name`` and load its weights.

        The model is moved to ``self.cuda``, stored on ``self.phi``, and then
        ``self.load_model(run_matchdg_erm)`` is called.

        Args:
            run_matchdg_erm: Forwarded unchanged to ``self.load_model``.

        Raises:
            ValueError: If ``self.args.model_name`` does not select any
                supported architecture.
        """
        model_name = self.args.model_name

        if model_name == 'lenet':
            from models.lenet import LeNet5
            phi = LeNet5()
        elif model_name == 'alexnet':
            from models.alexnet import alexnet
            phi = alexnet(self.args.out_classes, self.args.pre_trained,
                          self.args.method_name)
        elif model_name == 'domain_bed_mnist':
            from models.domain_bed_mnist import DomainBed
            phi = DomainBed(self.args.img_c)
        elif 'resnet' in model_name:
            from models.resnet import get_resnet
            phi = get_resnet(model_name, self.args.out_classes,
                             self.args.method_name, self.args.img_c,
                             self.args.pre_trained)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``phi`` below.
            raise ValueError(
                'Unsupported model_name: {!r}'.format(model_name))

        print('Model Architecture: ', self.args.model_name)

        self.phi = phi.to(self.cuda)
        self.load_model(run_matchdg_erm)

        return
コード例 #8
0
    def get_model(self):
        """Build the network named by ``self.args.model_name`` on ``self.cuda``.

        Returns:
            The constructed model after ``.to(self.cuda)``.

        Raises:
            ValueError: If ``self.args.model_name`` does not select any
                supported architecture.
        """
        model_name = self.args.model_name

        if model_name == 'lenet':
            from models.lenet import LeNet5
            phi = LeNet5()
        elif model_name == 'alexnet':
            from models.alexnet import alexnet
            phi = alexnet(self.args.out_classes, self.args.pre_trained,
                          self.args.method_name)
        elif model_name == 'domain_bed_mnist':
            from models.domain_bed_mnist import DomainBed
            phi = DomainBed(self.args.img_c)
        elif 'resnet' in model_name:
            from models.resnet import get_resnet
            # Methods 'csd' and 'matchdg_ctr' always build with fc_layer=0;
            # every other method uses the configured value.
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer = 0
            else:
                fc_layer = self.args.fc_layer
            phi = get_resnet(model_name, self.args.out_classes, fc_layer,
                             self.args.img_c, self.args.pre_trained)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``phi`` below.
            raise ValueError(
                'Unsupported model_name: {!r}'.format(model_name))

        print('Model Architecture: ', self.args.model_name)
        phi = phi.to(self.cuda)
        return phi