def solve(self):
    """Alternate target-domain clustering with network updates until done.

    If ``self.resume`` is set, ``self.iters`` and ``self.loop`` are first
    advanced by one. Each round then does the following.

    Under ``torch.no_grad()``:
      1-3. Generate target-domain pseudo-labels with ``self.update_labels()``.
           Then copy ``self.clustering.samples`` into
           ``self.clustered_target_samples``. Record the clustering's
           ``centers``, ``center_change`` and ``path2label`` with
           ``self.register_history`` (history length
           ``self.opt.CLUSTERING.HISTORY_LEN``).
           When the clustered samples carry ground truth (``'gt'`` is not
           None), print the clustering score under
           ``self.opt.EVAL_METRIC``.
      -    Stop as soon as ``self.complete_training()`` is truthy.
      4.   Filter out ambiguous samples/classes (``self.filtering()``). Then
           rebuild the categorical dataloaders and the iterations-per-loop
           setting from the result.

    With gradients enabled:
      5.   Run the k-step network update on the kept classes and increment
           ``self.loop``.

    Prints ``'Training Done!'`` when the loop exits.
    """
    if self.resume:
        self.iters += 1
        self.loop += 1

    while True:
        with torch.no_grad():
            print('Clustering based on %s...' % self.source_name)
            # Steps 1-3: generate pseudo-labels for the target domain.
            self.update_labels()
            clustering = self.clustering
            self.clustered_target_samples = clustering.samples

            # Update the history. All three values are read before the first
            # register_history call, because the tuple is built up front.
            history_len = self.opt.CLUSTERING.HISTORY_LEN
            snapshots = (
                ('target_centers', clustering.centers),
                ('ts_center_dist', clustering.center_change),
                ('target_labels', clustering.path2label),
            )
            for history_key, snapshot in snapshots:
                self.register_history(history_key, snapshot, history_len)

            # Model evaluation (only possible when ground truth is present).
            samples = self.clustered_target_samples
            if samples is not None and samples['gt'] is not None:
                score = self.model_eval(
                    to_onehot(samples['label'], self.opt.DATASET.NUM_CLASSES),
                    samples['gt'],
                )
                print('Clustering %s: %.4f' % (self.opt.EVAL_METRIC, score))

            # Check whether the stop condition has been met.
            if self.complete_training():
                break

            # Step 4: filter out ambiguous samples and classes.
            target_hypt, filtered_classes = self.filtering()
            # Refresh the dataloaders and the per-loop iteration setting.
            self.construct_categorical_dataloader(target_hypt, filtered_classes)
            self.compute_iters_per_loop(filtered_classes)

        # Step 5: k-step update of the network parameters through the
        # forward-backward pass. This runs outside no_grad so gradients exist.
        self.update_network(filtered_classes)
        self.loop += 1

    print('Training Done!')
def feature_clustering(self, feature_extractor, loader):
    """Cluster the collected target features and assign a label to every sample.

    Steps:
      1. ``self.collect_samples(feature_extractor, loader)`` fills
         ``self.samples``. ``self.samples['feature']`` is used as a 2-D
         ``(num_samples, feat_dim)`` tensor. ``self.samples['data']`` holds one
         key per sample (presumably image paths; confirm in
         ``collect_samples``).
      2. Alternate nearest-center assignment (``self.assign_labels``) with
         center updates until ``self.clustering_stop`` sets ``self.stop``.
         Each center is the *sum* of the features assigned to it. A cluster
         that receives no sample falls back to its row of
         ``self.init_centers``.
         NOTE(review): summed (not averaged) centers are only meaningful if
         ``self.Dist`` is scale-invariant (e.g. cosine); confirm the
         configured distance.
      3. Assign every sample to the final centers. Then map clusters to
         classes with ``self.align_centers()``.

    Features are processed in row chunks of at most ``self.max_len``,
    presumably to bound the memory used per ``assign_labels`` call.

    Side effects: sets ``self.stop``, ``self.centers``, ``self.center_change``,
    ``self.samples['label']`` and ``self.samples['dist2center']``. Updates
    ``self.path2label`` (sample key -> int label). Deletes
    ``self.samples['feature']``.

    Raises:
        ValueError: if ``collect_samples`` produced no features.
    """
    centers = None
    self.stop = False

    self.collect_samples(feature_extractor, loader)
    feature = self.samples['feature']
    num_samples = feature.size(0)
    if num_samples == 0:
        raise ValueError('feature_clustering: collect_samples produced no features')

    # Column of class indices, shape (num_classes, 1). Compared with a row of
    # labels it broadcasts to a (num_classes, chunk_len) membership mask.
    refs = torch.arange(self.num_classes, device=feature.device).unsqueeze(1)
    step = int(self.max_len)

    def chunks():
        # Consecutive row slices of at most `step` rows (the last one may be
        # shorter): ceil(num_samples / step) slices in total.
        for start in range(0, num_samples, step):
            yield feature.narrow(0, start, min(step, num_samples - start))

    while True:
        self.clustering_stop(centers)
        if centers is not None:
            self.centers = centers
        if self.stop:
            break

        centers = feature.new_zeros((self.num_classes, feature.size(1)))
        count = feature.new_zeros(self.num_classes)
        for cur_feature in chunks():
            _, cur_labels = self.assign_labels(cur_feature)
            # member[c, i] == 1 iff sample i of this chunk is assigned to c.
            member = (cur_labels.unsqueeze(0) == refs).to(feature.dtype)
            count += member.sum(dim=1)
            # Per-cluster sum of member features: (C, L) @ (L, D) -> (C, D).
            centers += member @ cur_feature

        # Empty clusters keep their initial center.
        has_members = (count > 0).unsqueeze(1).to(centers.dtype)
        centers = has_members * centers + (1 - has_members) * self.init_centers

    # Final assignment against the converged centers.
    dist_chunks, label_chunks = [], []
    for cur_feature in chunks():
        cur_dist, cur_labels = self.assign_labels(cur_feature)
        dist_chunks.append(cur_dist)
        label_chunks.append(cur_labels)
    self.samples['label'] = torch.cat(label_chunks, dim=0)
    self.samples['dist2center'] = torch.cat(dist_chunks, dim=0)

    # cluster2label[k] is the class matched to cluster k. This is how the
    # labels are remapped below; it assumes a permutation of
    # range(num_classes). Row c of the reordered centers must hold the center
    # of the cluster mapped to class c, i.e. use the *inverse* permutation.
    # Indexing with cluster2label directly is only correct when the
    # permutation is its own inverse.
    cluster2label = torch.as_tensor(self.align_centers()).long()
    self.centers = self.centers[torch.argsort(cluster2label).to(self.centers.device)]

    # Re-label every sample in one gather. Done in place, so the tensor stored
    # in self.samples['label'] is the one updated.
    labels = self.samples['label']
    labels.copy_(cluster2label.to(labels.device)[labels])

    self.center_change = torch.mean(
        self.Dist.get_dist(self.centers, self.init_centers))

    # A single tolist() transfer; the values are plain Python ints.
    self.path2label.update(zip(self.samples['data'], labels.tolist()))

    del self.samples['feature']