def test(self):
    """Evaluate the network on the test loader.

    Prints the confusion matrix accumulated over all batches and returns
    the result of ``self.model_eval`` on the concatenated logits/labels.
    """
    self.net.eval()
    labels = np.arange(self.opt.DATASET.NUM_CLASSES)
    preds = []
    gts = []
    # Running element-wise sum of per-batch confusion matrices.
    # The original `cm = []; cm += confusion_matrix(...)` extended the list
    # with matrix *rows* (list.__iadd__ never raises, so the bare-except
    # fallback never fired) instead of summing matrices.
    cm = np.zeros((len(labels), len(labels)), dtype=np.int64)
    for sample in iter(self.test_data['loader']):
        data, gt = to_cuda(sample['Img']), to_cuda(sample['Label'])
        logits = self.net(data)['logits']
        pred = torch.max(logits, dim=1).indices
        preds += [logits]
        gts += [gt]
        # `labels=` is keyword-only in recent scikit-learn releases
        cm += confusion_matrix(gt.cpu(), pred.cpu(), labels=labels)
    print('==================')
    print('Confusion matrix:')
    print(cm)
    print('==================')
    preds = torch.cat(preds, dim=0)
    gts = torch.cat(gts, dim=0)
    res = self.model_eval(preds, gts)
    return res
def update_network(self):
    """Run one training loop of supervised learning on labeled source data.

    Performs ``self.iters_per_loop`` iterations; each iteration samples a
    source batch, computes cross-entropy, backprops, and steps the
    optimizer.  Periodically logs, tests, and checkpoints based on the
    configured intervals.
    """
    # initial configuration
    stop = False
    update_iters = 0
    # rewind the source iterator so the loop sees a fresh epoch of data
    self.train_data[self.source_name]['iterator'] = iter(
        self.train_data[self.source_name]['loader'])
    while not stop:
        loss = 0
        # update learning rate
        self.update_lr()
        # set the status of network
        self.net.train()
        self.net.zero_grad()
        # conventional sampling for training on labeled source data
        source_sample = self.get_samples(self.train_domain)
        source_data, source_gt = source_sample['Img'],\
            source_sample['Label']
        source_data = to_cuda(source_data)
        source_gt = to_cuda(source_gt)
        # NOTE(review): set_bn_domain() is called with no argument here,
        # unlike the sibling solver which passes a bn_domain_map id —
        # presumably the single-domain default; confirm against the net API.
        self.net.module.set_bn_domain()
        source_preds = self.net(source_data)['logits']
        # compute the cross-entropy loss
        ce_loss = self.CELoss(source_preds, source_gt)
        ce_loss.backward()
        loss += ce_loss
        # update the network
        self.optimizer.step()
        # log roughly 10 times per loop
        if self.opt.TRAIN.LOGGING and (update_iters+1) % \
                (max(1, self.iters_per_loop // 10)) == 0:
            accu = self.model_eval(source_preds, source_gt)
            cur_loss = {'ce_loss': ce_loss}
            self.logging(cur_loss, accu)
        # periodic evaluation on the test set (no gradients)
        if self.opt.TRAIN.TEST_INTERVAL > 0 and \
                (self.iters+1) % int(self.opt.TRAIN.TEST_INTERVAL *
                                     self.iters_per_loop) == 0:
            with torch.no_grad():
                self.net.module.set_bn_domain()
                accu = self.test()
            print('Test at (loop %d, iters %d) with %s: %.4f.' %
                  (self.loop, self.iters, self.opt.EVAL_METRIC, accu))
        # periodic checkpointing
        if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
                (self.iters+1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL *
                                     self.iters_per_loop) == 0:
            self.save_ckpt()
        update_iters += 1
        self.iters += 1
        # update stop condition
        if update_iters >= self.iters_per_loop:
            stop = True
        else:
            stop = False
def D_step(self, x_S, x_T):
    """BCE loss for the domain discriminator.

    Detached source features are labeled 1.0, detached target features 0.0;
    the domain id is switched before each forward pass.
    """
    self.set_domain_id(0)
    src_out = self.net_D(x_S.detach())
    self.set_domain_id(1)
    tgt_out = self.net_D(x_T.detach())
    ones = to_cuda(torch.FloatTensor(src_out.size()).fill_(1.0))
    zeros = to_cuda(torch.FloatTensor(tgt_out.size()).fill_(0.0))
    all_preds = torch.cat((src_out, tgt_out), dim=0)
    all_targets = torch.cat((ones, zeros), dim=0)
    return self.BCELoss(all_preds, all_targets)
def test(self):
    """Run the network over the whole test loader and score it.

    Returns whatever ``self.model_eval`` produces for the concatenated
    logits and ground-truth labels.
    """
    self.net.eval()
    all_logits, all_labels = [], []
    for batch in self.test_data['loader']:
        imgs = to_cuda(batch['Img'])
        labels = to_cuda(batch['Label'])
        all_logits.append(self.net(imgs)['logits'])
        all_labels.append(labels)
    return self.model_eval(torch.cat(all_logits, dim=0),
                           torch.cat(all_labels, dim=0))
def test(self):
    """Evaluate the video network on the test loader.

    Returns the (iou, iou_pos, accu) triple produced by ``self.model_eval``
    on all concatenated outputs and labels.
    """
    self.net.eval()
    outputs, targets = [], []
    # each sample unpacks as (_, _, clip, label)
    for _, _, clip, label in self.test_data['loader']:
        clip = to_cuda(clip)
        label = to_cuda(label)
        outputs.append(self.net(clip))
        targets.append(label)
    return self.model_eval(torch.cat(outputs, dim=0),
                           torch.cat(targets, dim=0))
def patch_mean(self, nums_row, nums_col, dist, domain_probs_expand):
    """Weighted mean of each class-by-class patch of `dist`.

    Both `dist` and `domain_probs_expand` are carved into patches whose row
    and column extents are given by `nums_row` / `nums_col`; each output
    entry is sum(dist patch) / sum(weight patch), per domain.
    Returns a tensor of shape [num_classes, num_classes, num_domains].
    """
    assert (len(nums_row) == len(nums_col))
    n_cls = len(nums_row)
    n_dom = dist.size()[2]
    out = to_cuda(torch.zeros([n_cls, n_cls, n_dom]))
    r_off = 0
    for r in range(n_cls):
        c_off = 0
        for c in range(n_cls):
            block = dist.narrow(0, r_off, nums_row[r]).narrow(
                1, c_off, nums_col[c])
            weight = domain_probs_expand.narrow(0, r_off, nums_row[r]).narrow(
                1, c_off, nums_col[c])
            out[r, c] = torch.sum(block, [0, 1]) / torch.sum(weight, [0, 1])
            c_off += nums_col[c]
        r_off += nums_row[r]
    return out
def forward(self, source, target, nums_S, nums_T):
    """Contrastive Domain Discrepancy (CDD) between source and target.

    `source`/`target` are per-layer feature lists; `nums_S`/`nums_T` give
    the per-class sample counts in each batch.  Returns a dict with the
    combined 'cdd' objective plus its 'intra' and 'inter' components
    ('inter' is None when `self.intra_only`).
    """
    assert(len(nums_S) == len(nums_T)), \
        "The number of classes for source (%d) and target (%d) should be the same." \
        % (len(nums_S), len(nums_T))
    num_classes = len(nums_S)
    # compute the pairwise distances per feature layer
    dist_layers = []
    gamma_layers = []
    for i in range(self.num_layers):
        cur_source = source[i]
        cur_target = target[i]
        dist = {}
        dist['ss'] = self.compute_paired_dist(cur_source, cur_source)
        dist['tt'] = self.compute_paired_dist(cur_target, cur_target)
        dist['st'] = self.compute_paired_dist(cur_source, cur_target)
        # within-domain distances are split into per-class patches
        dist['ss'] = self.split_classwise(dist['ss'], nums_S)
        dist['tt'] = self.split_classwise(dist['tt'], nums_T)
        dist_layers += [dist]
        gamma_layers += [self.patch_gamma_estimation(nums_S, nums_T, dist)]
    # reshape gammas so they broadcast over each class patch
    for i in range(self.num_layers):
        for c in range(num_classes):
            gamma_layers[i]['ss'][c] = gamma_layers[i]['ss'][c].view(num_classes, 1, 1)
            gamma_layers[i]['tt'][c] = gamma_layers[i]['tt'][c].view(num_classes, 1, 1)
    # aggregate the multi-layer kernel values, then reduce to class pairs
    kernel_dist_st = self.kernel_layer_aggregation(dist_layers, gamma_layers, 'st')
    kernel_dist_st = self.patch_mean(nums_S, nums_T, kernel_dist_st)
    kernel_dist_ss = []
    kernel_dist_tt = []
    for c in range(num_classes):
        kernel_dist_ss += [torch.mean(self.kernel_layer_aggregation(
            dist_layers, gamma_layers, 'ss', c).view(num_classes, -1), dim=1)]
        kernel_dist_tt += [torch.mean(self.kernel_layer_aggregation(
            dist_layers, gamma_layers, 'tt', c).view(num_classes, -1), dim=1)]
    kernel_dist_ss = torch.stack(kernel_dist_ss, dim=0)
    # transpose so tt aligns with ss/st orientation in the MMD sum
    kernel_dist_tt = torch.stack(kernel_dist_tt, dim=0).transpose(1, 0)
    # MMD estimate per (source class, target class) pair
    mmds = kernel_dist_ss + kernel_dist_tt - 2 * kernel_dist_st
    # diagonal = same-class (intra) discrepancy, to be minimized
    intra_mmds = torch.diag(mmds, 0)
    intra = torch.sum(intra_mmds) / self.num_classes
    inter = None
    if not self.intra_only:
        # off-diagonal = cross-class (inter) discrepancy, to be maximized
        inter_mask = to_cuda((torch.ones([num_classes, num_classes]) \
            - torch.eye(num_classes)).type(torch.bool))
        inter_mmds = torch.masked_select(mmds, inter_mask)
        inter = torch.sum(inter_mmds) / (self.num_classes * (self.num_classes - 1))
    cdd = intra if inter is None else intra - inter
    return {'cdd': cdd, 'intra': intra, 'inter': inter}
def __init__(self, net, dataloader, resume=None, **kwargs):
    """Set up the solver: losses, counters, and the optimizer.

    `resume`, when given, is a dict holding 'epochs', 'iters' and
    'optimizer_state_dict' from a previous run.
    """
    self.opt = cfg
    self.net = net
    self.init_data(dataloader)
    # class-weighted CE: positive class weighted by cfg.TRAIN.WPOS
    self.CEWeight = to_cuda(
        torch.tensor([1.0 - cfg.TRAIN.WPOS, cfg.TRAIN.WPOS]))
    self.CELoss = nn.CrossEntropyLoss(weight=self.CEWeight)
    self.BCELoss = nn.BCELoss()
    if torch.cuda.is_available():
        self.CELoss.cuda()
        self.BCELoss.cuda()
    # bookkeeping counters and optimizer hyper-parameters
    self.iters, self.epochs = 0, 0
    self.iters_per_epoch = None
    self.base_lr = self.opt.TRAIN.BASE_LR
    self.momentum = self.opt.TRAIN.MOMENTUM
    self.optim_state_dict = None
    self.resume = resume is not None
    if self.resume:
        self.epochs = resume['epochs']
        self.iters = resume['iters']
        self.optim_state_dict = resume['optimizer_state_dict']
        print('Resume Training from iters %d, %d.' % \
              (self.epochs, self.iters))
    self.build_optimizer()
def G_step(self, x_S, x_T):
    """Generator-side adversarial loss.

    Pushes discriminator outputs on (non-detached) target features toward
    the 'source' label 1.0.  `x_S` is unused but kept for a signature
    symmetric with D_step.
    """
    self.set_domain_id(1)
    tgt_out = self.net_D(x_T)
    real_labels = to_cuda(torch.FloatTensor(tgt_out.size()).fill_(1.0))
    return self.BCELoss(tgt_out, real_labels)
def patch_gamma_estimation(self, nums_S, nums_T, dist):
    """Estimate one kernel bandwidth (gamma) per class pair.

    For every (source class ns, target class nt) pair, carves the matching
    patches out of dist['ss'], dist['tt'] and dist['st'], estimates a gamma
    from them, and records it symmetrically in gammas['ss'][ns][nt] and
    gammas['tt'][nt][ns], plus broadcast over the 'st' patch region.
    """
    assert (len(nums_S) == len(nums_T))
    num_classes = len(nums_S)
    patch = {}
    gammas = {}
    # 'st' gammas live on the full source-by-target grid
    gammas['st'] = to_cuda(
        torch.zeros_like(dist['st'], requires_grad=False))
    gammas['ss'] = []
    gammas['tt'] = []
    # one gamma vector (length num_classes) per class
    for c in range(num_classes):
        gammas['ss'] += [
            to_cuda(torch.zeros([num_classes], requires_grad=False))
        ]
        gammas['tt'] += [
            to_cuda(torch.zeros([num_classes], requires_grad=False))
        ]
    source_start = source_end = 0
    for ns in range(num_classes):
        # advance the source row window to class ns
        source_start = source_end
        source_end = source_start + nums_S[ns]
        patch['ss'] = dist['ss'][ns]
        target_start = target_end = 0
        for nt in range(num_classes):
            # advance the target column window to class nt
            target_start = target_end
            target_end = target_start + nums_T[nt]
            patch['tt'] = dist['tt'][nt]
            patch['st'] = dist['st'].narrow(0, source_start,
                                            nums_S[ns]).narrow(
                1, target_start, nums_T[nt])
            gamma = self.gamma_estimation(patch)
            # note the swapped indexing: ss is [ns][nt], tt is [nt][ns]
            gammas['ss'][ns][nt] = gamma
            gammas['tt'][nt][ns] = gamma
            gammas['st'][source_start:source_end, \
                target_start:target_end] = gamma
    return gammas
def test(self):
    """Evaluate the two-stage model (feature extractor + classifier).

    Returns ``self.model_eval`` computed on all concatenated logits and
    ground-truth labels from the test loader.
    """
    self.feature_extractor.eval()
    self.classifier.eval()
    all_logits, all_gts = [], []
    for batch in self.test_data['loader']:
        imgs = to_cuda(batch['Img'])
        labels = to_cuda(batch['Label'])
        # the extractor returns two feature maps; only the first feeds
        # the classifier
        feat_main, _ = self.feature_extractor(imgs)
        all_logits.append(self.classifier(feat_main))
        all_gts.append(labels)
    return self.model_eval(torch.cat(all_logits, dim=0),
                           torch.cat(all_gts, dim=0))
def test(self):
    """Segmentation evaluation on the target domain (bn domain id 1).

    Returns (global pixel accuracy %, mean IoU %) from the accumulated
    confusion matrix.
    """
    self.set_domain_id(1)
    self.net.eval()
    conmat = gen_utils.ConfusionMatrix(cfg.DATASET.NUM_CLASSES)
    for batch in self.test_data['loader']:
        imgs = gen_utils.to_cuda(batch['Img'])
        labels = gen_utils.to_cuda(batch['Label'])
        out = self.net(imgs)['out']
        # upsample logits to label resolution before taking the argmax
        out = F.interpolate(out, size=labels.shape[-2:], mode='bilinear',
                            align_corners=False)
        pred_ids = torch.max(out, dim=1).indices
        conmat.update(labels.flatten(), pred_ids.flatten())
    conmat.reduce_from_all_processes()
    accu, _, iou = conmat.compute()
    return 100.0 * accu.item(), 100.0 * iou.mean().item()
def get_centers(feature_extractor, dataloader, num_classes, key='feat'):
    """Accumulate per-class feature sums over one pass of `dataloader`.

    Features come from `feature_extractor` (first output), globally
    average-pooled to 2048-d vectors.  Returns a [num_classes, 2048]
    tensor of per-class feature *sums* — presumably normalized into
    means by the caller (TODO confirm).

    `key` is unused (kept for backward compatibility with the earlier
    net-dict based implementation).
    """
    centers = 0
    # column vector [num_classes, 1] of class ids, for mask broadcasting
    refs = to_cuda(torch.LongTensor(range(num_classes)).unsqueeze(1))
    for sample in iter(dataloader):
        data = to_cuda(sample['Img'])
        gt = to_cuda(sample['Label'])
        feature, _ = feature_extractor(data)
        feature = nn.AdaptiveAvgPool2d((1, 1))(feature).view(-1, 2048)
        feature = feature.data  # detach from the graph
        # mask[c, n, 1] == 1 iff sample n belongs to class c
        gt = gt.unsqueeze(0).expand(num_classes, -1)
        mask = (gt == refs).unsqueeze(2).type(torch.cuda.FloatTensor)
        feature = feature.unsqueeze(0)
        # add this batch's per-class feature sums
        centers += torch.sum(feature * mask, dim=1)
    return centers
def collect_samples(self, net, loader):
    """Cache paths, labels (when present) and features for every sample.

    Fills self.samples['data'/'gt'/'feature']; 'gt' is None when the
    loader carries no labels.
    """
    feats, gts, paths = [], [], []
    for batch in loader:
        imgs = batch['Img'].cuda()
        paths += batch['Path']
        if 'Label' in batch.keys():
            gts.append(to_cuda(batch['Label']))
        # .data: store detached features only
        feats.append(net.forward(imgs)[self.feat_key].data)
    self.samples['data'] = paths
    self.samples['gt'] = torch.cat(gts, dim=0) if len(gts) > 0 else None
    self.samples['feature'] = torch.cat(feats, dim=0)
def compute_kernel_dist(self, dist, gamma, kernel_num, kernel_mul):
    """Multi-bandwidth RBF kernel value for a distance tensor.

    Builds `kernel_num` bandwidths centered (geometrically) on `gamma`
    with ratio `kernel_mul`, clamps degenerate values, and returns the
    sum over bandwidths of exp(-dist / bandwidth).
    """
    smallest = gamma / (kernel_mul**(kernel_num // 2))
    bandwidths = to_cuda(torch.tensor(
        [smallest * (kernel_mul**i) for i in range(kernel_num)]))
    # floor tiny bandwidths at eps to avoid division blow-ups
    eps = 1e-5
    tiny = (bandwidths < eps).type(torch.cuda.FloatTensor)
    bandwidths = ((1.0 - tiny) * bandwidths + tiny * eps).detach()
    scaled = dist.unsqueeze(0) / bandwidths.view(-1, 1, 1)
    # clamp the scaled distances into [1e-5, 1e5] for numerical stability
    too_big = (scaled > 1e5).type(torch.cuda.FloatTensor).detach()
    too_small = (scaled < 1e-5).type(torch.cuda.FloatTensor).detach()
    keep = 1.0 - too_big - too_small
    scaled = keep * scaled + too_big * 1e5 + too_small * 1e-5
    return torch.sum(torch.exp(-1.0 * scaled), dim=0)
def patch_mean(self, nums_row, nums_col, dist):
    """Mean of each class-by-class patch of `dist`.

    Row/column extents per class come from `nums_row`/`nums_col`.
    Returns a [num_classes, num_classes] tensor of patch means.
    """
    assert(len(nums_row) == len(nums_col))
    n_cls = len(nums_row)
    means = to_cuda(torch.zeros([n_cls, n_cls]))
    r_off = 0
    for r in range(n_cls):
        c_off = 0
        for c in range(n_cls):
            patch = dist.narrow(0, r_off, nums_row[r]).narrow(
                1, c_off, nums_col[c])
            means[r, c] = torch.mean(patch)
            c_off += nums_col[c]
        r_off += nums_row[r]
    return means
def collect_samples(self, feature_extractor, loader):
    """Cache paths, labels (when present) and pooled backbone features.

    The extractor's first output is global-average-pooled to 2048-d
    vectors; results are stored in self.samples['data'/'gt'/'feature'].
    """
    feats, gts, paths = [], [], []
    for batch in loader:
        imgs = batch['Img'].cuda()
        paths += batch['Path']
        if 'Label' in batch.keys():
            gts.append(to_cuda(batch['Label']))
        fmap, _ = feature_extractor(imgs)
        pooled = nn.AdaptiveAvgPool2d((1, 1))(fmap).view(-1, 2048)
        # .data: keep detached features only
        feats.append(pooled.data)
    self.samples['data'] = paths
    self.samples['gt'] = torch.cat(gts, dim=0) if len(gts) > 0 else None
    self.samples['feature'] = torch.cat(feats, dim=0)
def step(opt, data_loader, model, to_train=True, optimizer=None):
    """Used as a training step or validation step.

    Iterates once over `data_loader`, forwarding each batch through
    `model` (which is expected to return a dict with a 'loss' entry).
    When `to_train` is True, `optimizer` must be provided and one
    optimization step is taken per batch.  Returns the running average
    loss over the pass.
    """
    nIters = len(data_loader)
    loss_meter = AverageMeter()
    with tqdm(total=nIters) as t:
        for i, data in enumerate(data_loader):
            # ===================forward=====================
            if opt.toCuda:
                data = to_cuda(data, device())
            # pop so the remaining dict holds only the auxiliary targets
            image = data.pop('image')
            out_dict = model(image, data)
            loss = out_dict['loss']
            # ===================backward====================
            if to_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # average is weighted by batch size
            loss_meter.update(loss.detach().cpu().item(), image.size(0))
            t.set_postfix(loss='{:10.8f}'.format(loss_meter.avg))
            t.update()
    return loss_meter.avg
def update_network(self):
    """One epoch of supervised video-segmentation training.

    Each iteration samples a clip/mask pair, optimizes a mixed
    dice + boundary (BF) loss, and periodically logs, tests, and saves
    checkpoints.  (Dead commented-out experiments — auxiliary head,
    CE-on-flattened-logits, downsampled-mask losses — removed.)
    """
    stop = False
    update_iters = 0
    while not stop:
        self.net.train()
        self.net.zero_grad()
        # Adam adapts its own step sizes; only anneal the lr otherwise
        if self.opt.TRAIN.OPTIMIZER != "Adam":
            self.update_lr()
        # get the video clip and corresponding mask
        _, _, vclip, vlabel = self.get_samples()
        vclip = to_cuda(vclip)
        vlabel = to_cuda(vlabel)
        # forward: predictions of shape N x C x D x H x W
        vout = self.net(vclip)
        vprobs = F.softmax(vout, dim=1)
        # mixed objective: (1-alpha) dice + alpha boundary loss
        alpha = 0.3
        loss = (1.0 - alpha) * solver_utils.dice_loss(vprobs, vlabel)
        loss += alpha * solver_utils.BF_loss(vprobs, vlabel)
        loss.backward()
        self.optimizer.step()
        if self.opt.TRAIN.LOGGING and (update_iters+1) % \
                (max(1, self.iters_per_epoch // self.opt.TRAIN.NUM_LOGGING_PER_EPOCH)) == 0:
            iou, iou_pos, accuracy = self.model_eval(vout, vlabel)
            print('Training at (epoch %d, iters: %d) with loss, iou, iou_pos, accuracy: %.4f, %.4f, %.4f, %.4f.'
                  % (self.epochs, self.iters, loss, iou, iou_pos, accuracy))
        if self.opt.TRAIN.TEST_INTERVAL > 0 and \
                (self.iters+1) % int(self.opt.TRAIN.TEST_INTERVAL * self.iters_per_epoch) == 0:
            with torch.no_grad():
                iou, iou_pos, accuracy = self.test()
            print('Test at (epoch %d, iters: %d): %.4f, %.4f, %.4f.'
                  % (self.epochs, self.iters, iou, iou_pos, accuracy))
        if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
                (self.iters+1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL * self.iters_per_epoch) == 0:
            self.save_ckpt()
        update_iters += 1
        self.iters += 1
        # stop after one epoch worth of iterations
        stop = update_iters >= self.iters_per_epoch
def feature_clustering(self, feature_extractor, loader):
    """K-means-style clustering of target features to assign pseudo-labels.

    Collects features for the whole loader, iterates center updates until
    ``self.clustering_stop`` sets ``self.stop``, aligns cluster ids to
    class labels, and records per-path pseudo-labels in
    ``self.path2label``.  Features are processed in chunks of at most
    ``self.max_len`` samples to bound memory.
    """
    centers = None
    self.stop = False
    self.collect_samples(feature_extractor, loader)
    feature = self.samples['feature']
    # column vector of cluster ids, used to build one-hot-like masks
    refs = to_cuda(torch.LongTensor(range(self.num_classes)).unsqueeze(1))
    num_samples = feature.size(0)
    # number of max_len-sized chunks needed to cover all samples
    num_split = ceil(1.0 * num_samples / self.max_len)

    while True:
        # convergence test compares the candidate centers to the current ones
        self.clustering_stop(centers)
        if centers is not None:
            self.centers = centers
        if self.stop:
            break
        centers = 0
        count = 0
        start = 0
        for N in range(num_split):
            cur_len = min(self.max_len, num_samples - start)
            cur_feature = feature.narrow(0, start, cur_len)
            dist2center, labels = self.assign_labels(cur_feature)
            labels_onehot = to_onehot(labels, self.num_classes)
            count += torch.sum(labels_onehot, dim=0)
            labels = labels.unsqueeze(0)
            # mask[c, n, 1] == 1 iff sample n is currently assigned to c
            mask = (labels == refs).unsqueeze(2).type(
                torch.cuda.FloatTensor)
            reshaped_feature = cur_feature.unsqueeze(0)
            # update centers
            centers += torch.sum(reshaped_feature * mask, dim=1)
            start += cur_len
        # empty clusters fall back to their initial centers
        mask = (count.unsqueeze(1) > 0).type(torch.cuda.FloatTensor)
        centers = mask * centers + (1 - mask) * self.init_centers

    # final assignment pass with the converged centers
    dist2center, labels = [], []
    start = 0
    count = 0
    for N in range(num_split):
        cur_len = min(self.max_len, num_samples - start)
        cur_feature = feature.narrow(0, start, cur_len)
        cur_dist2center, cur_labels = self.assign_labels(cur_feature)
        labels_onehot = to_onehot(cur_labels, self.num_classes)
        count += torch.sum(labels_onehot, dim=0)
        dist2center += [cur_dist2center]
        labels += [cur_labels]
        start += cur_len
    self.samples['label'] = torch.cat(labels, dim=0)
    self.samples['dist2center'] = torch.cat(dist2center, dim=0)

    # map cluster indices onto class labels
    cluster2label = self.align_centers()
    # reorder the centers
    self.centers = self.centers[cluster2label, :]
    # re-label the data according to the index
    num_samples = len(self.samples['feature'])
    for k in range(num_samples):
        self.samples['label'][k] = cluster2label[self.samples['label']
                                                 [k]].item()
    # mean drift of centers relative to their initial positions
    self.center_change = torch.mean(self.Dist.get_dist(self.centers,
                                                       self.init_centers))
    for i in range(num_samples):
        self.path2label[self.samples['data']
                        [i]] = self.samples['label'][i].item()
    # features are no longer needed; free the memory
    del self.samples['feature']
def forward(self, source, target, nums_S, nums_T, domain_probs,
            source_sample_labels):
    """Domain-probability-weighted ('soft') CDD objective.

    Extends the plain CDD forward with per-sample source domain
    probabilities (`domain_probs`, indexed by each sample's label) so
    every MMD term is estimated per latent domain K.  Returns a dict with
    'cdd' plus its 'intra' and 'inter' components.
    """
    assert(len(nums_S) == len(nums_T)), \
        "The number of classes for source (%d) and target (%d) should be the same." \
        % (len(nums_S), len(nums_T))
    num_classes = len(nums_S)
    num_domains = domain_probs.size(2)
    proper_labels = source_sample_labels
    # pick each sample's probability row for its own class: Ns x K
    domain_probs_simple = domain_probs[torch.arange(
        domain_probs.size()[0]), proper_labels]  # Ns x K
    # pairwise product of sample domain probs: Ns x Ns x K
    paired_domain_probs = self.compute_paired_domain_prob(
        domain_probs_simple)  # Ns x Ns x K
    paired_domain_probs_ss_classwise = self.split_paired_dp_classwise(
        paired_domain_probs, nums_S)

    # compute the per-layer pairwise distances
    dist_layers = []
    gamma_layers = []
    for i in range(self.num_layers):
        cur_source = source[i]
        cur_target = target[i]
        dist = {}
        dist['ss'] = self.compute_paired_dist(cur_source, cur_source)
        dist['tt'] = self.compute_paired_dist(cur_target, cur_target)
        dist['st'] = self.compute_paired_dist(cur_source, cur_target)
        dist['ss'] = self.split_classwise(dist['ss'], nums_S)
        dist['tt'] = self.split_classwise(dist['tt'], nums_T)
        dist_layers += [dist]
        gamma_layers += [self.patch_gamma_estimation(nums_S, nums_T, dist)]

    # reshape gammas so they broadcast over each class patch
    for i in range(self.num_layers):
        for c in range(num_classes):
            gamma_layers[i]['ss'][c] = gamma_layers[i]['ss'][c].view(
                num_classes, 1, 1)
            gamma_layers[i]['tt'][c] = gamma_layers[i]['tt'][c].view(
                num_classes, 1, 1)

    kernel_dist_st = self.kernel_layer_aggregation(
        dist_layers, gamma_layers, 'st')  # Ns x Nt
    assert kernel_dist_st.size()[0] == domain_probs_simple.size()[0]
    # broadcast the st kernel and the source domain probs over K
    kernel_dist_st_expand = kernel_dist_st.unsqueeze(2).expand(
        kernel_dist_st.size()[0], kernel_dist_st.size()[1],
        num_domains)  # Ns x Nt x K
    domain_probs_simple_expand = domain_probs_simple.unsqueeze(1).expand(
        kernel_dist_st.size()[0], kernel_dist_st.size()[1],
        num_domains)  # Ns x Nt x K
    kernel_dist_st_soft = kernel_dist_st_expand * domain_probs_simple_expand
    # weighted patch mean: num_classes x num_classes x K
    kernel_dist_st_soft = self.patch_mean(
        nums_S, nums_T, kernel_dist_st_soft, domain_probs_simple_expand)

    kernel_dist_ss_soft = []
    kernel_dist_tt_soft = []
    for c in range(num_classes):
        kernel_dist_ss = self.kernel_layer_aggregation(
            dist_layers, gamma_layers, 'ss', c)  # num_classes x N_c x N_c
        paired_dp_ss_c = paired_domain_probs_ss_classwise[
            c]  # num_classes x N_c x N_c x K
        kernel_dist_ss_expand = kernel_dist_ss.unsqueeze(3).expand(
            kernel_dist_ss.size()[0], kernel_dist_ss.size()[1],
            kernel_dist_ss.size()[2],
            num_domains)  # num_classes x N_c x N_c x K
        temp_mult = kernel_dist_ss_expand * paired_dp_ss_c  # num_classes x N_c x N_c x K
        # probability-weighted mean over the patch: num_classes x K
        kernel_dist_ss_soft += [
            torch.sum(temp_mult.view(num_classes, -1, num_domains), dim=1) /
            torch.sum(paired_dp_ss_c.view(num_classes, -1, num_domains),
                      dim=1)
        ]  # list of num_classes x K
        # tt has no domain weighting; replicate the plain mean across K
        temp_tt = torch.mean(self.kernel_layer_aggregation(
            dist_layers, gamma_layers, 'tt', c).view(num_classes, -1),
            dim=1)
        kernel_dist_tt_soft += [
            temp_tt.unsqueeze(1).expand(num_classes, num_domains)
        ]  # list of num_classes x K
    kernel_dist_ss_soft = torch.stack(kernel_dist_ss_soft, dim=0)
    kernel_dist_tt_soft = torch.stack(kernel_dist_tt_soft,
                                      dim=0).transpose(1, 0)
    mmds = kernel_dist_ss_soft + kernel_dist_tt_soft - 2 * kernel_dist_st_soft  # num_classes x num_classes x K

    # nc1: source-vs-target terms; nc2: source-vs-source terms
    nc2_intra = to_cuda(torch.zeros(1))
    nc2_inter = to_cuda(torch.zeros(1))
    nc1_intra = to_cuda(torch.zeros(1))
    nc1_inter = to_cuda(torch.zeros(1))
    for i in range(num_classes):
        for j in range(num_classes):
            if i == j:
                nc1_intra += torch.mean(kernel_dist_ss_soft[i, i] +
                                        kernel_dist_tt_soft[j, j] -
                                        2 * kernel_dist_st_soft[i, j])
            else:
                nc1_inter += torch.mean(kernel_dist_ss_soft[i, i] +
                                        kernel_dist_tt_soft[j, j] -
                                        2 * kernel_dist_st_soft[i, j])
    for i in range(num_classes):
        for j in range(num_classes):
            if i == j:
                nc2_intra += torch.mean(kernel_dist_ss_soft[i, i] +
                                        kernel_dist_ss_soft[j, j] -
                                        2 * kernel_dist_ss_soft[i, j])
            else:
                nc2_inter += torch.mean(kernel_dist_ss_soft[i, i] +
                                        kernel_dist_ss_soft[j, j] -
                                        2 * kernel_dist_ss_soft[i, j])
    # normalize by the number of (ordered) class pairs
    nc1_intra = nc1_intra[0] / (self.num_classes)
    nc1_inter = nc1_inter[0] / (self.num_classes * (self.num_classes - 1))
    nc2_intra = nc2_intra[0] / (self.num_classes)
    nc2_inter = nc2_inter[0] / (self.num_classes * (self.num_classes - 1))
    # minimize intra terms; additionally maximize inter terms unless intra_only
    cdd = nc1_intra + nc2_intra if self.intra_only else nc1_intra + nc2_intra - nc1_inter - nc2_inter
    return {
        'cdd': cdd,
        'intra': nc1_intra + nc2_intra,
        'inter': nc1_inter + nc2_inter
    }
def update_network(self, filtered_classes):
    """One training loop of CAN-style adaptation.

    Each iteration: (1) cross-entropy on a labeled source batch, and
    (2) when `filtered_classes` is non-empty, a CDD alignment loss on
    class-aware-sampled source/target batches.  Gradients from both
    losses accumulate before a single optimizer step.  Periodically
    logs, tests (on the target bn domain), and checkpoints.
    """
    # initial configuration
    stop = False
    update_iters = 0
    self.train_data[self.source_name]['iterator'] = \
        iter(self.train_data[self.source_name]['loader'])
    self.train_data['categorical']['iterator'] = \
        iter(self.train_data['categorical']['loader'])
    while not stop:
        # update learning rate
        self.update_lr()
        # set the status of network
        self.net.train()
        self.net.zero_grad()
        loss = 0
        ce_loss_iter = 0
        cdd_loss_iter = 0
        # conventional sampling for training on labeled source data
        source_sample = self.get_samples(self.source_name)
        source_data, source_gt = source_sample['Img'], \
            source_sample['Label']
        source_data = to_cuda(source_data)
        source_gt = to_cuda(source_gt)
        # source batch uses the source BN statistics
        self.net.module.set_bn_domain(self.bn_domain_map[self.source_name])
        source_preds = self.net(source_data)['logits']
        # compute the cross-entropy loss
        ce_loss = self.CELoss(source_preds, source_gt)
        ce_loss.backward()
        ce_loss_iter += ce_loss
        loss += ce_loss
        if len(filtered_classes) > 0:
            # update the network parameters
            # 1) class-aware sampling
            source_samples_cls, source_nums_cls, \
                target_samples_cls, target_nums_cls = self.CAS()
            # 2) forward and compute the loss
            source_cls_concat = torch.cat(
                [to_cuda(samples) for samples in source_samples_cls],
                dim=0)
            target_cls_concat = torch.cat(
                [to_cuda(samples) for samples in target_samples_cls],
                dim=0)
            # each domain is forwarded under its own BN statistics
            self.net.module.set_bn_domain(
                self.bn_domain_map[self.source_name])
            feats_source = self.net(source_cls_concat)
            self.net.module.set_bn_domain(
                self.bn_domain_map[self.target_name])
            feats_target = self.net(target_cls_concat)
            # prepare the features
            feats_toalign_S = self.prepare_feats(feats_source)
            feats_toalign_T = self.prepare_feats(feats_target)
            cdd_loss = self.cdd.forward(
                feats_toalign_S, feats_toalign_T, source_nums_cls,
                target_nums_cls)[self.discrepancy_key]
            cdd_loss *= self.opt.CDD.LOSS_WEIGHT
            cdd_loss.backward()
            cdd_loss_iter += cdd_loss
            loss += cdd_loss
        # update the network (gradients from both losses)
        self.optimizer.step()
        if self.opt.TRAIN.LOGGING and (update_iters + 1) % \
                (max(1, self.iters_per_loop // self.opt.TRAIN.NUM_LOGGING_PER_LOOP)) == 0:
            accu = self.model_eval(source_preds, source_gt)
            cur_loss = {
                'ce_loss': ce_loss_iter,
                'cdd_loss': cdd_loss_iter,
                'total_loss': loss
            }
            self.logging(cur_loss, accu)
        # intervals are fractions of a loop; cap them at one full loop
        self.opt.TRAIN.TEST_INTERVAL = min(1.0,
                                           self.opt.TRAIN.TEST_INTERVAL)
        self.opt.TRAIN.SAVE_CKPT_INTERVAL = min(
            1.0, self.opt.TRAIN.SAVE_CKPT_INTERVAL)
        if self.opt.TRAIN.TEST_INTERVAL > 0 and \
                (update_iters + 1) % int(self.opt.TRAIN.TEST_INTERVAL * self.iters_per_loop) == 0:
            with torch.no_grad():
                # evaluate with the target domain's BN statistics
                self.net.module.set_bn_domain(
                    self.bn_domain_map[self.target_name])
                accu = self.test()
                print('Test at (loop %d, iters: %d) with %s: %.4f.' %
                      (self.loop, self.iters, self.opt.EVAL_METRIC, accu))
        if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
                (update_iters + 1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL * self.iters_per_loop) == 0:
            self.save_ckpt()
        update_iters += 1
        self.iters += 1
        # update stop condition
        if update_iters >= self.iters_per_loop:
            stop = True
        else:
            stop = False
def test(args):
    """Standalone evaluation entry point for the classification model.

    Builds the network (optionally loading adapted weights and their BN
    domain map), runs inference over the test loader (with optional
    five-crop averaging), saves predictions, and prints the configured
    evaluation metric when labels are available.
    """
    # prepare data
    dataloader = prepare_data()
    # initialize model
    model_state_dict = None
    fx_pretrained = True
    bn_domain_map = {}
    if cfg.WEIGHTS != '':
        weights_dict = torch.load(cfg.WEIGHTS)
        model_state_dict = weights_dict['weights']
        bn_domain_map = weights_dict['bn_domain_map']
        # loaded weights replace the ImageNet-pretrained backbone
        fx_pretrained = False
    # adapted models carry one BN branch per domain
    if args.adapted_model:
        num_domains_bn = 2
    else:
        num_domains_bn = 1
    net = model.danet(num_classes=cfg.DATASET.NUM_CLASSES,
                      state_dict=model_state_dict,
                      feature_extractor=cfg.MODEL.FEATURE_EXTRACTOR,
                      fx_pretrained=fx_pretrained,
                      dropout_ratio=cfg.TRAIN.DROPOUT_RATIO,
                      fc_hidden_dims=cfg.MODEL.FC_HIDDEN_DIMS,
                      num_domains_bn=num_domains_bn)
    net = torch.nn.DataParallel(net)
    if torch.cuda.is_available():
        net.cuda()
    # test
    res = {}
    res['path'], res['preds'], res['gt'], res['probs'] = [], [], [], []
    net.eval()
    # fall back to domain 0 when the tested domain has no BN mapping
    if cfg.TEST.DOMAIN in bn_domain_map:
        domain_id = bn_domain_map[cfg.TEST.DOMAIN]
    else:
        domain_id = 0
    with torch.no_grad():
        net.module.set_bn_domain(domain_id)
        for sample in iter(dataloader):
            res['path'] += sample['Path']
            if cfg.DATA_TRANSFORM.WITH_FIVE_CROP:
                # fold the crop dimension into the batch, then average
                # the predicted probabilities over the crops
                n, ncrop, c, h, w = sample['Img'].size()
                sample['Img'] = sample['Img'].view(-1, c, h, w)
                img = to_cuda(sample['Img'])
                probs = net(img)['probs']
                probs = probs.view(n, ncrop, -1).mean(dim=1)
            else:
                img = to_cuda(sample['Img'])
                probs = net(img)['probs']
            preds = torch.max(probs, dim=1)[1]
            res['preds'] += [preds]
            res['probs'] += [probs]
            if 'Label' in sample:
                label = to_cuda(sample['Label'])
                res['gt'] += [label]
            print('Processed %d samples.' % len(res['path']))
    preds = torch.cat(res['preds'], dim=0)
    save_preds(res['path'], preds, cfg.SAVE_DIR)
    # score only when ground truth was present in the loader
    if 'gt' in res and len(res['gt']) > 0:
        gts = torch.cat(res['gt'], dim=0)
        probs = torch.cat(res['probs'], dim=0)
        assert (cfg.EVAL_METRIC == 'mean_accu'
                or cfg.EVAL_METRIC == 'accuracy')
        if cfg.EVAL_METRIC == "mean_accu":
            eval_res = mean_accuracy(probs, gts)
            print('Test mean_accu: %.4f' % (eval_res))
        elif cfg.EVAL_METRIC == "accuracy":
            eval_res = accuracy(probs, gts)
            print('Test accuracy: %.4f' % (eval_res))
    print('Finished!')
def update_network(self, **kwargs):
    """One epoch of McDalNet training.

    Supports several domain-alignment objectives selected by
    ``opt.MCDALNET.DISTANCE_TYPE``: 'SourceOnly' (plain CE), 'DANN'
    (domain classifier BCE), 'MDD', or a generic discrepancy via
    ``self.McDalNetLoss``.  `lam` is the standard DANN-style ramp-up
    coefficient, updated per epoch or per iteration depending on
    ``opt.TRAIN.PROCESS_COUNTER``.
    """
    stop = False
    self.train_data['source']['iterator'] = iter(
        self.train_data['source']['loader'])
    self.train_data['target']['iterator'] = iter(
        self.train_data['target']['loader'])
    # epoch length follows the target loader
    self.iters_per_epoch = len(self.train_data['target']['loader'])
    iters_counter_within_epoch = 0
    data_time = AverageMeter()
    batch_time = AverageMeter()
    total_loss = AverageMeter()
    ce_loss = AverageMeter()
    da_loss = AverageMeter()
    prec1_task = AverageMeter()
    prec1_aux1 = AverageMeter()
    prec1_aux2 = AverageMeter()
    self.net.train()
    end = time.time()
    if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
        # ramp lam from 0 to 1 over training (DANN schedule)
        lam = 2 / (1 + math.exp(
            -1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
        self.update_lr()
        print('value of lam is: %3f' % (lam))
    while not stop:
        if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
            lam = 2 / (1 + math.exp(
                -1 * 10 * self.iters /
                (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
            print('value of lam is: %3f' % (lam))
            self.update_lr()
        source_data, source_gt = self.get_samples('source')
        target_data, _ = self.get_samples('target')
        source_data = to_cuda(source_data)
        source_gt = to_cuda(source_gt)
        target_data = to_cuda(target_data)
        data_time.update(time.time() - end)
        # net returns (features, main logits, two auxiliary heads,
        # domain-classifier output, and truncated auxiliary heads)
        feature_source, output_source, output_source1, output_source2, output_source_dc, output_source1_trunc, output_source2_trunc = self.net(
            source_data, lam)
        loss_task_auxiliary_1 = self.CELoss(output_source1_trunc, source_gt)
        loss_task_auxiliary_2 = self.CELoss(output_source2_trunc, source_gt)
        loss_task = self.CELoss(output_source, source_gt)
        if self.opt.MCDALNET.DISTANCE_TYPE != 'SourceOnly':
            feature_target, output_target, output_target1, output_target2, output_target_dc, output_target1_trunc, output_target2_trunc = self.net(
                target_data, lam)
            if self.opt.MCDALNET.DISTANCE_TYPE == 'DANN':
                # domain classifier: source labeled 0, target labeled 1
                num_source = source_data.size()[0]
                num_target = target_data.size()[0]
                dlabel_source = to_cuda(torch.zeros(num_source, 1))
                dlabel_target = to_cuda(torch.ones(num_target, 1))
                loss_domain_all = self.BCELoss(
                    output_source_dc, dlabel_source) + self.BCELoss(
                    output_target_dc, dlabel_target)
                loss_all = loss_task + loss_domain_all
            elif self.opt.MCDALNET.DISTANCE_TYPE == 'MDD':
                prob_target1 = F.softmax(output_target1, dim=1)
                # head 2 provides the pseudo labels for head 1
                _, target_pseudo_label = torch.topk(output_target2, 1)
                batch_index = torch.arange(output_target.size()[0]).long()
                pred_gt_prob = prob_target1[
                    batch_index,
                    target_pseudo_label]  ## the prob values of the predicted gt
                # guard against log(0) when the prob equals one
                pred_gt_prob = process_one_values(pred_gt_prob)
                loss_domain_target = (1 - pred_gt_prob).log().mean()
                _, source_pseudo_label = torch.topk(output_source2, 1)
                loss_domain_source = self.CELoss(output_source1,
                                                 source_pseudo_label[:, 0])
                loss_domain_all = loss_domain_source - loss_domain_target
                loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2
            else:
                # generic discrepancy between the two auxiliary heads
                loss_domain_source = self.McDalNetLoss(
                    output_source1, output_source2,
                    self.opt.MCDALNET.DISTANCE_TYPE)
                loss_domain_target = self.McDalNetLoss(
                    output_target1, output_target2,
                    self.opt.MCDALNET.DISTANCE_TYPE)
                # minimize on source, maximize on target
                loss_domain_all = loss_domain_source - loss_domain_target
                loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2
            da_loss.update(loss_domain_all, source_data.size()[0])
        else:
            loss_all = loss_task
        ce_loss.update(loss_task, source_data.size()[0])
        total_loss.update(loss_all, source_data.size()[0])
        prec1_task.update(accuracy(output_source, source_gt),
                          source_data.size()[0])
        prec1_aux1.update(accuracy(output_source1, source_gt),
                          source_data.size()[0])
        prec1_aux2.update(accuracy(output_source2, source_gt),
                          source_data.size()[0])
        self.optimizer.zero_grad()
        loss_all.backward()
        self.optimizer.step()
        print(" Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
              (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg))
        batch_time.update(time.time() - end)
        end = time.time()
        self.iters += 1
        iters_counter_within_epoch += 1
        if iters_counter_within_epoch >= self.iters_per_epoch:
            # append the epoch summary to the log file
            log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
            log.write("\n")
            log.write(" Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                      (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg))
            log.close()
            stop = True
def test(args):
    """Evaluate a segmentation network on the test set.

    Builds the network from cfg, optionally loads weights, wraps it for
    (distributed) data-parallel inference, runs the test dataloader and
    prints a confusion matrix. Optionally saves colorized prediction /
    ground-truth masks when cfg.TEST.VISUALIZE is set.

    Args:
        args: runtime arguments; uses args.distributed and args.gpu.
    """
    # initialize model: load checkpoint weights on CPU first, if provided
    model_state_dict = None
    if cfg.WEIGHTS != '':
        param_dict = torch.load(cfg.WEIGHTS, torch.device('cpu'))
        model_state_dict = param_dict['weights']

    net = SegNet.__dict__[cfg.MODEL.NETWORK_NAME](
        pretrained=False,
        pretrained_backbone=False,
        num_classes=cfg.DATASET.NUM_CLASSES,
        aux_loss=cfg.MODEL.USE_AUX_CLASSIFIER)

    if args.distributed:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

    if cfg.MODEL.DOMAIN_BN:
        net = DomainBN.convert_domain_batchnorm(net, num_domains=2)

    if model_state_dict is not None:
        try:
            net.load_state_dict(model_state_dict)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatch;
            # retry after converting to domain-specific BN, since the
            # checkpoint may have been trained with DOMAIN_BN enabled
            # (was a bare `except:`, which also swallowed KeyboardInterrupt).
            net = DomainBN.convert_domain_batchnorm(net, num_domains=2)
            net.load_state_dict(model_state_dict)

    # select which set of BN statistics to use at test time
    if cfg.TEST.DOMAIN == 'source':
        DomainBN.set_domain_id(net, 0)
    if cfg.TEST.DOMAIN == 'target':
        DomainBN.set_domain_id(net, 1)

    if torch.cuda.is_available():
        net.cuda()

    if args.distributed:
        net = DistributedDataParallel(net, device_ids=[args.gpu])
    else:
        net = torch.nn.DataParallel(net)

    test_dataset, test_dataloader = prepare_data(args)

    net.eval()
    with torch.no_grad():
        conmat = gen_utils.ConfusionMatrix(
            cfg.DATASET.NUM_CLASSES,
            list(LABEL_TASK['%s2%s' % (cfg.DATASET.SOURCE,
                                       cfg.DATASET.TARGET)].keys()))
        for sample in iter(test_dataloader):
            data, gt = gen_utils.to_cuda(sample['Img']), gen_utils.to_cuda(
                sample['Label'])
            names = sample['Name']
            res = net(data)
            if cfg.TEST.WITH_AGGREGATION:
                # blend raw features/probabilities with spatially aggregated
                # versions (fixed mixing weight alpha)
                feats = res['feat']
                alpha = 0.5
                feats = (1.0 - alpha) * feats + alpha * AssociationLoss(
                ).spatial_agg(feats)[-1]
                preds = F.softmax(net.module.classifier(feats), dim=1)
                preds = (1.0 - alpha) * preds + alpha * AssociationLoss(
                ).spatial_agg(preds, metric='kl')[-1]
            else:
                preds = res['out']

            # upsample logits to ground-truth resolution, then take argmax
            preds = F.interpolate(preds, size=gt.shape[-2:],
                                  mode='bilinear', align_corners=False)
            preds = torch.max(preds, dim=1).indices

            if cfg.TEST.VISUALIZE:
                for i in range(preds.size(0)):
                    cur_pred = preds[i, :, :].cpu().numpy()
                    cur_gt = gt[i, :, :].cpu().numpy()
                    # map training ids back to dataset label ids; copies are
                    # needed so earlier remappings don't feed later ones
                    cur_pred_cp = cur_pred.copy()
                    cur_gt_cp = cur_gt.copy()
                    label_map = label_map_gtav if cfg.DATASET.SOURCE == 'GTAV' else label_map_syn
                    for n in range(cfg.DATASET.NUM_CLASSES):
                        cur_pred[cur_pred_cp == n] = label_map[n]
                        cur_gt[cur_gt_cp == n] = label_map[n]
                    # keep the ignore label (255) from the ground truth
                    cur_pred = np.where(cur_gt == 255, cur_gt, cur_pred)
                    cur_pred = np.asarray(cur_pred, dtype=np.uint8)
                    cur_gt = np.asarray(cur_gt, dtype=np.uint8)
                    vis_res = colorize_mask(cur_pred)
                    vis_gt = colorize_mask(cur_gt)
                    vis_name = 'vis_%s.png' % (names[i])
                    vis_res.save(os.path.join(cfg.SAVE_DIR, vis_name))
                    vis_name = 'vis_gt_%s.png' % (names[i])
                    vis_gt.save(os.path.join(cfg.SAVE_DIR, vis_name))

            conmat.update(gt.flatten(), preds.flatten())

        conmat.reduce_from_all_processes()

    print('Test with %d samples: ' % len(test_dataset))
    print(conmat)
    print('Finished!')
def test(self):
    """Evaluate the two-head classifier on the test loader.

    The classifier output is split into two halves: columns
    [:num_classes] ("Fs") and [num_classes:] ("Ft"). Depending on
    self.opt.EVAL_METRIC this reports either overall accuracy ('accu')
    or class-mean accuracy ('accu_mean'), appends a summary to
    SAVE_DIR/log.txt, and returns the better of the two heads' scores.
    """
    self.feature_extractor.eval()
    self.classifier.eval()
    prec1_fs = AverageMeter()
    prec1_ft = AverageMeter()
    # per-class sample counts and correct counts for each head
    counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
    for i, (input, target) in enumerate(self.test_data['loader']):
        input, target = to_cuda(input), to_cuda(target)
        with torch.no_grad():
            feature_test = self.feature_extractor(input)
            output_test = self.classifier(feature_test)
        if self.opt.EVAL_METRIC == 'accu':
            # plain top-1 accuracy for both halves of the output
            prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target)
            prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
            prec1_fs.update(prec1_fs_iter, input.size(0))
            prec1_ft.update(prec1_ft_iter, input.size(0))
            if i % self.opt.PRINT_STEP == 0:
                print(" Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \
                      (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg))
        elif self.opt.EVAL_METRIC == 'accu_mean':
            # class-balanced accuracy: accumulate per-class counters
            prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
            prec1_ft.update(prec1_ft_iter, input.size(0))
            counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs)
            counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft)
            if i % self.opt.PRINT_STEP == 0:
                print(" Test:epoch: %d:[%d/%d], Task: %3f" % \
                      (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg))
        else:
            raise NotImplementedError
    # NOTE(review): in 'accu' mode the counters are never updated, so these
    # divisions are 0/0 (NaN tensors); they are unused on that branch.
    acc_for_each_class_fs = counter_acc_fs / counter_all_fs
    acc_for_each_class_ft = counter_acc_ft / counter_all_ft
    log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
    log.write("\n")
    if self.opt.EVAL_METRIC == 'accu':
        log.write(
            "                                                          Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
            (self.epoch, prec1_fs.avg, prec1_ft.avg))
        log.close()
        return max(prec1_fs.avg, prec1_ft.avg)
    elif self.opt.EVAL_METRIC == 'accu_mean':
        log.write(
            "                                                          Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
            (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean()))
        log.write("\nClass-wise Acc of Ft:")  ## based on the task classifier.
        for i in range(self.opt.DATASET.NUM_CLASSES):
            # ordinal suffixes (1st, 2nd, 3rd, nth) for the log line
            if i == 0:
                log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i]))
            elif i == 1:
                log.write(", %dnd: %3f" % (i + 1, acc_for_each_class_ft[i]))
            elif i == 2:
                log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i]))
            else:
                log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i]))
        log.close()
        return max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean())
def test(self):
    """Evaluate the network's task head and two auxiliary heads.

    The network forward returns (among others) a task output and two
    auxiliary outputs; the second forward argument (lam) does not affect
    inference (see inline note below). Depending on self.opt.EVAL_METRIC
    this reports overall accuracy ('accu') or class-mean accuracy
    ('accu_mean'), appends a summary to SAVE_DIR/log.txt, and returns
    the best score among the three heads.
    """
    self.net.eval()
    prec1_task = AverageMeter()
    prec1_auxi1 = AverageMeter()
    prec1_auxi2 = AverageMeter()
    # per-class sample counts and correct counts for each head
    counter_all = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_all_auxi1 = torch.FloatTensor(
        self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_all_auxi2 = torch.FloatTensor(
        self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_acc = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_acc_auxi1 = torch.FloatTensor(
        self.opt.DATASET.NUM_CLASSES).fill_(0)
    counter_acc_auxi2 = torch.FloatTensor(
        self.opt.DATASET.NUM_CLASSES).fill_(0)
    for i, (input, target) in enumerate(self.test_data['loader']):
        input, target = to_cuda(input), to_cuda(target)
        with torch.no_grad():
            _, output_test, output_test1, output_test2, _, _, _ = self.net(
                input, 1)  ## the value of lam do not affect the test process
        if self.opt.EVAL_METRIC == 'accu':
            # plain top-1 accuracy for task head and both auxiliaries
            prec1_task_iter = accuracy(output_test, target)
            prec1_auxi1_iter = accuracy(output_test1, target)
            prec1_auxi2_iter = accuracy(output_test2, target)
            prec1_task.update(prec1_task_iter, input.size(0))
            prec1_auxi1.update(prec1_auxi1_iter, input.size(0))
            prec1_auxi2.update(prec1_auxi2_iter, input.size(0))
            if i % self.opt.PRINT_STEP == 0:
                print(" Test:epoch: %d:[%d/%d], Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                      (self.epoch, i, len(self.test_data['loader']), prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg))
        elif self.opt.EVAL_METRIC == 'accu_mean':
            # class-balanced accuracy: accumulate per-class counters
            prec1_task_iter = accuracy(output_test, target)
            prec1_task.update(prec1_task_iter, input.size(0))
            counter_all, counter_acc = accuracy_for_each_class(
                output_test, target, counter_all, counter_acc)
            counter_all_auxi1, counter_acc_auxi1 = accuracy_for_each_class(
                output_test1, target, counter_all_auxi1, counter_acc_auxi1)
            counter_all_auxi2, counter_acc_auxi2 = accuracy_for_each_class(
                output_test2, target, counter_all_auxi2, counter_acc_auxi2)
            if i % self.opt.PRINT_STEP == 0:
                print(" Test:epoch: %d:[%d/%d], Task: %3f" % \
                      (self.epoch, i, len(self.test_data['loader']), prec1_task.avg))
        else:
            raise NotImplementedError
    # NOTE(review): in 'accu' mode these counters stay zero, so the
    # divisions are 0/0 (NaN tensors); they are unused on that branch.
    acc_for_each_class = counter_acc / counter_all
    acc_for_each_class_auxi1 = counter_acc_auxi1 / counter_all_auxi1
    acc_for_each_class_auxi2 = counter_acc_auxi2 / counter_all_auxi2
    log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
    log.write("\n")
    if self.opt.EVAL_METRIC == 'accu':
        log.write(
            " Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \
            (self.epoch, prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg))
        log.close()
        return max(prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg)
    elif self.opt.EVAL_METRIC == 'accu_mean':
        log.write(
            " Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \
            (self.epoch, acc_for_each_class_auxi1.mean(), acc_for_each_class_auxi2.mean(), acc_for_each_class.mean()))
        log.write("\nClass-wise Acc:")  ## based on the task classifier.
        for i in range(self.opt.DATASET.NUM_CLASSES):
            # ordinal suffixes (1st, 2nd, 3rd, nth) for the log line
            if i == 0:
                log.write("%dst: %3f" % (i + 1, acc_for_each_class[i]))
            elif i == 1:
                log.write(", %dnd: %3f" % (i + 1, acc_for_each_class[i]))
            elif i == 2:
                log.write(", %drd: %3f" % (i + 1, acc_for_each_class[i]))
            else:
                log.write(", %dth: %3f" % (i + 1, acc_for_each_class[i]))
        log.close()
        return max(acc_for_each_class_auxi1.mean(),
                   acc_for_each_class_auxi2.mean(),
                   acc_for_each_class.mean())
def update_network(self, **kwargs):
    """Run one epoch of adversarial training with separate optimizers.

    Alternates two updates per iteration on the same computation graph:
    (1) the classifier is trained with task + discrimination losses, and
    (2) the feature extractor is trained with confusion losses, where the
    target-confusion term is ramped in by a factor `lam` that grows from
    0 toward 1 over training (sigmoid-shaped schedule, computed per epoch
    or per iteration depending on TRAIN.PROCESS_COUNTER).
    """
    stop = False
    self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
    self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
    # one "epoch" is defined by the length of the target loader
    self.iters_per_epoch = len(self.train_data['target']['loader'])
    iters_counter_within_epoch = 0
    data_time = AverageMeter()
    batch_time = AverageMeter()
    classifier_loss = AverageMeter()
    feature_extractor_loss = AverageMeter()
    prec1_fs = AverageMeter()
    prec1_ft = AverageMeter()
    self.feature_extractor.train()
    self.classifier.train()
    end = time.time()
    if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
        # ramp factor in [0, 1): 2/(1+exp(-10*p)) - 1 with p = progress
        lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
        self.update_lr()
        print('value of lam is: %3f' % (lam))
    while not stop:
        if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
            # same ramp, but progress is measured in iterations
            lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
            print('value of lam is: %3f' % (lam))
            self.update_lr()
        source_data, source_gt = self.get_samples('source')
        target_data, _ = self.get_samples('target')  # target labels unused
        source_data = to_cuda(source_data)
        source_gt = to_cuda(source_gt)
        target_data = to_cuda(target_data)
        data_time.update(time.time() - end)

        feature_source = self.feature_extractor(source_data)
        output_source = self.classifier(feature_source)
        feature_target = self.feature_extractor(target_data)
        output_target = self.classifier(feature_target)

        # classifier losses: each half of the 2K-way output is trained as a
        # task classifier, plus source/target discrimination terms
        loss_task_fs = self.CELoss(output_source[:,:self.num_classes], source_gt)
        loss_task_ft = self.CELoss(output_source[:,self.num_classes:], source_gt)
        loss_discrim_source = self.CELoss(output_source, source_gt)
        loss_discrim_target = self.TargetDiscrimLoss(output_target)
        loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target

        # feature-extractor losses: confuse the two halves by also training
        # toward the label shifted into the second half of the output
        source_gt_for_ft_in_fst = source_gt + self.num_classes
        loss_confusion_source = 0.5 * self.CELoss(output_source, source_gt) + 0.5 * self.CELoss(output_source, source_gt_for_ft_in_fst)
        loss_confusion_target = self.ConcatenatedCELoss(output_target)
        loss_summary_feature_extractor = loss_confusion_source + lam * loss_confusion_target

        # retain_graph=True: the feature-extractor backward below reuses the
        # same forward graph
        self.optimizer_classifier.zero_grad()
        loss_summary_classifier.backward(retain_graph=True)
        self.optimizer_classifier.step()

        self.optimizer_feature_extractor.zero_grad()
        loss_summary_feature_extractor.backward()
        self.optimizer_feature_extractor.step()

        classifier_loss.update(loss_summary_classifier, source_data.size()[0])
        feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0])
        prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0])
        prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0])

        print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
              (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))

        batch_time.update(time.time() - end)
        end = time.time()
        self.iters += 1
        iters_counter_within_epoch += 1
        if iters_counter_within_epoch >= self.iters_per_epoch:
            # end of epoch: append the running averages to the log file
            log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
            log.write("\n")
            log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                      (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))
            log.close()
            stop = True
def test_and_save_mask(net, test_dataloader):
    """Run the network over video clips and save spatio-temporal proposals.

    For each clip: forward the network, take the softmax probability of
    channel 1 (foreground) per frame, optionally refine it with dense CRF,
    grow/filter regions into per-frame bounding boxes, associate boxes
    across frames into track proposals, then resize proposals back to the
    original video resolution and save them.

    Args:
        net: segmentation network; output is softmaxed over dim 1
            (class channel). # assumes output is N x C x D x H x W — TODO confirm
        test_dataloader: yields (vid, start_f, clips[, vmask]) batches.
    """
    clip_len = cfg.DATASET.CLIP_LEN
    clip_stride = cfg.DATASET.CLIP_STRIDE
    for sample in iter(test_dataloader):
        if cfg.TEST.WITH_MASK:
            vid, start_f, clips, vmask = sample
        else:
            vid, start_f, clips = sample
        # skip clips that are too short (fewer than 8 frames)
        if clips.size(2) < 8:
            continue
        # forward and get the prediction result
        vpred = net(to_cuda(clips))
        # N x D x H x W
        probs = F.softmax(vpred, dim=1)
        pos_probs = probs[:, 1, :, :, :]  # probability of the positive class
        start_f = start_f.numpy()
        N = len(vid)
        assert(N == len(start_f))
        for i in range(N):
            cur_vid = vid[i]
            cur_video_path = os.path.join(cfg.DATASET.DATAROOT, 'videos',
                                          '%s.%s'%(cur_vid, cfg.DATASET.VIDEO_FORMAT))
            print('video: %s, start_f: %d' % (cur_video_path, start_f[i]))
            # TODO: for debugging
            #if start_f[i] < 5344:
            #    continue
            cur_video = VideoReader(cur_video_path)
            #cur_video.seek(int(start_f[i]))
            frame_count = 0
            proposals = []
            clip_imgs = []
            #masks = []
            #for frame in cur_video.get_iter(clip_len * clip_stride):
            #    if frame_count % clip_stride:
            #        frame_count += 1
            #        continue
            #    # read the image
            #    img = frame.numpy()
            #    assert(len(img.shape) > 1)
            #    clip_imgs += [img]
            #    #img = img[:, :, [2, 1, 0]]
            #    #img = (img / 255.) * 2 - 1
            # iterate frame indices; only every clip_stride-th frame is used
            for fid in range(clip_len * clip_stride):
                if frame_count % clip_stride:
                    frame_count += 1
                    continue
                count = frame_count // clip_stride  # index into the network output
                if not cfg.TEST.WITH_DENSE_CRF:
                    cur_pos_probs = pos_probs[i, count, :, :].cpu().numpy()
                else:
                    # refine probabilities with a dense CRF on the frame image
                    cur_probs = probs[i, :, count, :, :].cpu().numpy()
                    # TODO: need normalize or not?
                    # undo the [-1, 1] normalization back to uint8 pixels
                    resized_img = clips[i, :, count, :, :].cpu().numpy()
                    resized_img = np.uint8(255 * (resized_img + 1.) / 2.0)
                    cur_pos_probs = 1.0 * dense_crf(cur_probs, resized_img)
                # NOTE(review): smoothing() appears to modify cur_pos_probs
                # in place — its return value is ignored; confirm in helper.
                smoothing(cur_pos_probs)
                labels, num_regions = regional_growing(cur_pos_probs, pixel_val_thres=0.3)
                cur_pos_probs, bboxes = filtering(cur_pos_probs, labels, num_regions, 5)
                #masks.append(cur_pos_probs)
                if len(proposals) == 0:
                    # first frame with detections: start one track per box
                    proposals = [[(count, bbox)] for bbox in bboxes]
                else:
                    # match this frame's boxes to existing tracks
                    associate_bboxes(count, bboxes, proposals)
                frame_count += 1
            #heatmaps = draw_heatmaps(clip_imgs, masks)
            #save_visualizations(heatmaps, 'heatmaps', cur_vid, start_f[i])
            #h, w, _ = clip_imgs[0].shape
            h, w = cur_video.height, cur_video.width
            # keep only tracks spanning at least 7 frames
            new_proposals = [prop for prop in proposals if len(prop) >= 7]
            print('Number of proposals before and after filtering: %d, %d'
                  % (len(proposals), len(new_proposals)))
            if len(new_proposals) == 0:
                continue
            #for prop in new_proposals:
            #    print(prop)
            # scale boxes from network resolution back to video resolution
            stride_x = 1.0 * w / cfg.DATA_TRANSFORM.FINESIZE
            stride_y = 1.0 * h / cfg.DATA_TRANSFORM.FINESIZE
            new_proposals = resize_proposals(new_proposals, stride_x, stride_y, w, h)
            save_proposals(new_proposals, 'proposals', cur_vid, start_f[i])
def update_network(self, filtered_classes):
    """Run one loop of source CE training plus class-aware alignment.

    Each iteration trains on a labeled source batch with cross-entropy,
    and — when `filtered_classes` is non-empty — additionally draws
    class-aware samples (CAS), computes a CDD discrepancy between source
    and target features together with per-class entropy/KL terms from the
    domain prediction network (dpn), and backpropagates the combined loss.
    Periodically logs, tests, and checkpoints based on the TRAIN options.

    Args:
        filtered_classes: classes selected for alignment; alignment is
            skipped entirely when empty.
    """
    # initial configuration
    stop = False
    update_iters = 0

    self.train_data[self.source_name]['iterator'] = \
        iter(self.train_data[self.source_name]['loader'])
    self.train_data['categorical']['iterator'] = \
        iter(self.train_data['categorical']['loader'])

    while not stop:
        # update learning rate
        self.update_lr()

        # set the status of network
        self.net.train()
        self.dpn.train()
        self.net.zero_grad()
        self.dpn.zero_grad()

        loss = 0
        ce_loss_iter = 0
        cdd_loss_iter = 0

        # coventional sampling for training on labeled source data
        source_sample = self.get_samples(self.source_name)
        source_data, source_gt = source_sample['Img'],\
                      source_sample['Label']

        source_data = to_cuda(source_data)
        source_gt = to_cuda(source_gt)
        self.net.module.set_bn_domain(self.bn_domain_map[self.source_name])
        source_preds = self.net(source_data)['logits']

        # compute the cross-entropy loss
        ce_loss = self.CELoss(source_preds, source_gt)
        ce_loss.backward()

        ce_loss_iter += ce_loss
        loss += ce_loss

        if len(filtered_classes) > 0:
            # update the network parameters
            # 1) class-aware sampling
            source_samples_cls, source_nums_cls, \
                target_samples_cls, target_nums_cls, source_sample_labels = self.CAS()
            source_sample_labels = torch.cat(source_sample_labels, dim=0).cuda()

            # 2) forward and compute the loss
            source_cls_concat = torch.cat(
                [to_cuda(samples) for samples in source_samples_cls], dim=0)
            target_cls_concat = torch.cat(
                [to_cuda(samples) for samples in target_samples_cls], dim=0)

            # forward each domain under its own BN statistics
            self.net.module.set_bn_domain(
                self.bn_domain_map[self.source_name])
            feats_source = self.net(source_cls_concat)
            self.net.module.set_bn_domain(
                self.bn_domain_map[self.target_name])
            feats_target = self.net(target_cls_concat)

            # prepare the features
            feats_toalign_S = self.prepare_feats(feats_source)
            feats_toalign_T = self.prepare_feats(feats_target)

            # domain probabilities per class from the domain prediction net
            domain_logits = self.dpn(source_cls_concat)
            domain_logits = domain_logits.reshape(
                domain_logits.shape[0], self.opt.DATASET.NUM_CLASSES, -1)
            domain_prob_s = torch.zeros(domain_logits.shape,
                                        dtype=torch.float32).cuda()
            kl_loss = 0
            entropy_loss = 0
            num_active_classes = 0
            for i in range(self.opt.DATASET.NUM_CLASSES):
                indexes = source_sample_labels == i
                if indexes.sum() == 0:
                    # class not present in this sampled batch
                    continue
                entropy_loss_cl, domain_prob_s_cl = self.entropy_loss(
                    domain_logits[indexes, i])
                kl_loss += -self.get_domain_entropy(domain_prob_s_cl)
                entropy_loss += entropy_loss_cl
                domain_prob_s[indexes, i] = domain_prob_s_cl
                num_active_classes += 1

            # guard against ZeroDivisionError when no class had samples;
            # both accumulators are 0 in that case, so max(1, .) is a no-op
            # whenever num_active_classes > 0
            entropy_loss = entropy_loss * self.clustering_wt / max(
                1, num_active_classes)
            kl_loss = kl_loss * self.clustering_wt / max(
                1, num_active_classes)

            cdd_loss = self.cdd.forward(
                feats_toalign_S, feats_toalign_T,
                source_nums_cls, target_nums_cls,
                domain_prob_s, source_sample_labels)[self.discrepancy_key]

            total_loss = cdd_loss * self.opt.CDD.LOSS_WEIGHT + entropy_loss + kl_loss
            total_loss.backward()
            print("Entropy loss:", entropy_loss, "KL_loss:", kl_loss)

            cdd_loss_iter += total_loss
            loss += total_loss

        # update the network
        self.optimizer.step()

        if self.opt.TRAIN.LOGGING and (update_iters+1) % \
                  (max(1, self.iters_per_loop // self.opt.TRAIN.NUM_LOGGING_PER_LOOP)) == 0:
            accu = self.model_eval(source_preds, source_gt)
            cur_loss = {
                'ce_loss': ce_loss_iter,
                'cdd_loss': cdd_loss_iter,
                'total_loss': loss
            }
            self.logging(cur_loss, accu)

        # intervals are fractions of a loop; clamp to at most one full loop
        self.opt.TRAIN.TEST_INTERVAL = min(1.0, self.opt.TRAIN.TEST_INTERVAL)
        self.opt.TRAIN.SAVE_CKPT_INTERVAL = min(
            1.0, self.opt.TRAIN.SAVE_CKPT_INTERVAL)

        if self.opt.TRAIN.TEST_INTERVAL > 0 and \
            (update_iters+1) % int(self.opt.TRAIN.TEST_INTERVAL *
                                   self.iters_per_loop) == 0:
            with torch.no_grad():
                # evaluate under the target domain's BN statistics
                self.net.module.set_bn_domain(
                    self.bn_domain_map[self.target_name])
                accu = self.test()
                print('Test at (loop %d, iters: %d) with %s: %.4f.'
                      % (self.loop, self.iters, self.opt.EVAL_METRIC, accu))

        if self.opt.TRAIN.SAVE_CKPT_INTERVAL > 0 and \
            (update_iters+1) % int(self.opt.TRAIN.SAVE_CKPT_INTERVAL *
                                   self.iters_per_loop) == 0:
            self.save_ckpt()

        update_iters += 1
        self.iters += 1

        # update stop condition
        if update_iters >= self.iters_per_loop:
            stop = True
        else:
            stop = False