def get_match_function(self, epoch):
    """Compute (or refresh) the cross-domain match tensors used for training.

    At epoch 0 the match is requested with ``inferred_match=0`` (the original
    comment describes this as starting from a "randomly defined batch"). At
    later epochs an inferred match is requested (``inferred_match=1``), passing
    the current ``self.phi`` to ``get_matched_pairs``.

    Args:
        epoch: Current training epoch; ``epoch > 0`` selects the inferred match.

    Returns:
        ``(data_match_tensor, label_match_tensor)`` from ``get_matched_pairs``.
        Returns ``(None, None)`` when ``epoch > 0`` and ``self.args.match_flag``
        is falsy. The previous code raised ``UnboundLocalError`` in that case,
        because its result tensors were never bound. Presumably a falsy
        ``match_flag`` means "do not refresh the match", so callers should keep
        their existing tensors — TODO confirm against the training loop.
    """
    # Start initially with the non-inferred batch; afterwards find the local
    # approximate (inferred) batch.
    inferred_match = 1 if epoch > 0 else 0

    if inferred_match and not self.args.match_flag:
        # The old code ran a full matching pass here and threw every result
        # away; skip the work and signal "no update" explicitly.
        return None, None

    perfect_match = self.args.perfect_match
    data_match_tensor, label_match_tensor, _, _ = get_matched_pairs(
        self.args,
        self.cuda,
        self.train_dataset,
        self.domain_size,
        self.total_domains,
        self.training_list_size,
        self.phi,
        self.args.match_case,
        perfect_match,
        inferred_match,
    )
    return data_match_tensor, label_match_tensor
def get_metric_eval(self):
    """Score the inferred cross-domain match on the training data.

    Populates ``self.metric_score`` with:
      * 'Perfect Match Score'      -- ``perfect_match_score`` of the matched indices;
      * 'TopK Perfect Match Score' -- percentage of match ranks below ``self.top_k``;
      * 'Perfect Match Rank'       -- mean of the match ranks;
    and prints each of them. Returns None.
    """
    use_inferred_match = 1

    # NOTE(review): this passes 9 positional arguments with the inferred-match
    # flag last, whereas other get_matched_pairs calls in this file pass a
    # separate perfect_match argument before inferred_match. Confirm this
    # matches the signature imported by this module.
    _, _, matched_indices, rank_list = get_matched_pairs(
        self.args,
        self.cuda,
        self.train_dataset,
        self.domain_size,
        self.total_domains,
        self.training_list_size,
        self.phi,
        self.args.match_case,
        use_inferred_match,
    )

    match_score = perfect_match_score(matched_indices)
    ranks = np.array(rank_list)

    self.metric_score['Perfect Match Score'] = match_score
    self.metric_score['TopK Perfect Match Score'] = (
        100 * np.sum(ranks < self.top_k) / ranks.shape[0]
    )
    self.metric_score['Perfect Match Rank'] = np.mean(ranks)

    for metric_name in (
        'Perfect Match Score',
        'TopK Perfect Match Score',
        'Perfect Match Rank',
    ):
        print(metric_name + ': ', self.metric_score[metric_name])
def init_erm_phase(self):
    """Load the trained MatchDG CTR-phase network and build the initial match.

    The network is selected by ``self.args.ctr_model_name`` and moved to
    ``self.cuda``. Its weights are restored from the MatchDG CTR results
    directory and it is put in eval mode. It is then passed as ``phi`` to
    ``get_matched_pairs`` over the training data.

    Returns:
        ``(data_match_tensor, label_match_tensor)`` from ``get_matched_pairs``.

    Raises:
        ValueError: if ``self.args.ctr_model_name`` is not one of
            'lenet', 'alexnet' or 'resnet18'. The previous code fell through
            and failed later with ``UnboundLocalError`` on ``ctr_phi``.
    """
    model_name = self.args.ctr_model_name
    # Model modules are imported lazily so only the selected backbone loads.
    if model_name == 'lenet':
        from models.lenet import LeNet5
        ctr_phi = LeNet5()
    elif model_name == 'alexnet':
        from models.alexnet import alexnet
        ctr_phi = alexnet(self.args.out_classes, self.args.pre_trained, 'matchdg_ctr')
    elif model_name == 'resnet18':
        from models.resnet import get_resnet
        ctr_phi = get_resnet(
            'resnet18',
            self.args.out_classes,
            'matchdg_ctr',
            self.args.img_c,
            self.args.pre_trained,
        )
    else:
        raise ValueError(
            f"Unsupported ctr_model_name {model_name!r}; "
            "expected 'lenet', 'alexnet' or 'resnet18'"
        )
    ctr_phi = ctr_phi.to(self.cuda)

    # Load MatchDG CTR phase model from the saved weights.
    base_res_dir = (
        f"results/{self.args.dataset_name}/matchdg_ctr/{self.args.ctr_match_layer}"
        f"/train_{self.args.train_domains}_test_{self.args.test_domains}"
    )
    save_path = base_res_dir + '/Model_' + self.ctr_load_post_string + '.pth'
    ctr_phi.load_state_dict(torch.load(save_path))
    ctr_phi.eval()

    # match_case == -1 selects the inferred match; any other value uses the
    # initial x% percentage match strategy (inferred_match=0).
    inferred_match = 1 if self.args.match_case == -1 else 0

    # NOTE(review): 9 positional arguments, with no separate perfect_match;
    # other get_matched_pairs calls in this file pass perfect_match before
    # inferred_match. Confirm against the signature imported here.
    data_match_tensor, label_match_tensor, _, _ = get_matched_pairs(
        self.args,
        self.cuda,
        self.train_dataset,
        self.domain_size,
        self.total_domains,
        self.training_list_size,
        ctr_phi,
        self.args.match_case,
        inferred_match,
    )
    return data_match_tensor, label_match_tensor
def get_metric_eval(self):
    """Evaluate the match function on the split chosen by ``args.match_func_data_case``.

    The split dict (``self.train_dataset`` / ``self.val_dataset`` /
    ``self.test_dataset``) supplies the data loader and domain bookkeeping
    passed to ``get_matched_pairs``, which is run with ``inferred_match=1``
    using ``self.phi``. When ``args.match_func_aug_case`` is truthy, perfect
    matching is forced; otherwise ``args.perfect_match`` is used.

    Populates ``self.metric_score`` with:
      * 'Perfect Match Score'      -- ``perfect_match_score`` of the matched indices;
      * 'TopK Perfect Match Score' -- percentage of match ranks below ``args.top_k``;
      * 'Perfect Match Rank'       -- mean of the match ranks;
    and prints each of them.

    Raises:
        ValueError: if ``args.match_func_data_case`` is not 'train', 'val' or
            'test'. Previously this fell through to an ``UnboundLocalError``.
    """
    data_case = self.args.match_func_data_case
    # Map to attribute names so only the selected split is ever accessed.
    split_attr = {
        'train': 'train_dataset',
        'val': 'val_dataset',
        'test': 'test_dataset',
    }.get(data_case)
    if split_attr is None:
        raise ValueError(
            f"Unknown match_func_data_case {data_case!r}; "
            "expected 'train', 'val' or 'test'"
        )
    split = getattr(self, split_attr)
    dataset = split['data_loader']
    total_domains = split['total_domains']
    base_domain_size = split['base_domain_size']
    domain_size_list = split['domain_size_list']

    inferred_match = 1
    # Self Augmentation Match Function evaluation will always follow perfect matches.
    perfect_match = 1 if self.args.match_func_aug_case else self.args.perfect_match

    _, _, indices_matched, perfect_match_rank = get_matched_pairs(
        self.args,
        self.cuda,
        dataset,
        base_domain_size,
        total_domains,
        domain_size_list,
        self.phi,
        self.args.match_case,
        perfect_match,
        inferred_match,
    )

    score = perfect_match_score(indices_matched)
    ranks = np.array(perfect_match_rank)
    self.metric_score['Perfect Match Score'] = score
    self.metric_score['TopK Perfect Match Score'] = (
        100 * np.sum(ranks < self.args.top_k) / ranks.shape[0]
    )
    self.metric_score['Perfect Match Rank'] = np.mean(ranks)
    # TODO: a 'Perfect Match Distance' metric (feature-space discrepancy between
    # matched samples under self.phi) was sketched here but never implemented.

    print('Perfect Match Score: ', self.metric_score['Perfect Match Score'])
    print('TopK Perfect Match Score: ', self.metric_score['TopK Perfect Match Score'])
    print('Perfect Match Rank: ', self.metric_score['Perfect Match Rank'])