def sample_select(softmax_out_target, features_tar, ad_net, iter): # with torch.no_grad(): ad_out = ad_net(features_tar, test=True).view(-1) entropy = losssntg.Entropy(softmax_out_target) entropy = 1.0 + torch.exp(-entropy) threshold = 1.6 - 0.1 * iter / 2000 # w_e = torch.threshold(entropy,threshold,0) # return ad_out*w_e w_e = torch.threshold(entropy, threshold, -1) w_e = torch.threshold(-w_e, 0, 0) return 1 - w_e
def cal_pseudolabel_w(loader, ad_net, base_net): iter_tar = iter(loader['target']) ad_net.train(False) base_net.train(False) with torch.no_grad(): D_output_false = [] F_output_false = [] D_output_true = [] F_output_true = [] F_entropy_false = [] F_max_false = [] F_entropy_true = [] F_max_true = [] for i in range(len(loader['target'])): inputs_tar, labels_tar = iter_tar.next() inputs_tar, labels_tar = inputs_tar.cuda(), labels_tar.cuda() features_tar, outputs_tar = base_net(inputs_tar) softmax_out_target = nn.Softmax(dim=1)(outputs_tar) pseudo_label = torch.max(softmax_out_target, 1)[1] # ad_out = ad_net(features_tar,test=True) entropy = losssntg.Entropy(softmax_out_target) entropy = 1.0 + torch.exp(-entropy) pred = (labels_tar == pseudo_label) for j in range(inputs_tar.size(0)): if pred[j].item() == 0: # D_output_false.append(ad_out[j].item()) F_entropy_false.append(entropy[j].item()) F_max_false.append(softmax_out_target[j].max().item()) elif pred[j].item() == 1: # D_output_true.append(ad_out[j].item()) F_entropy_true.append(entropy[j].item()) F_max_true.append(softmax_out_target[j].max().item()) return { 'D_output_false': D_output_false, 'D_output_true': D_output_true, 'F_entropy_false': F_entropy_false, 'F_entropy_true': F_entropy_true, 'F_max_false': F_max_false, 'F_max_true': F_max_true }
def sample_select(softmax_out_target, iter, type): # with torch.no_grad(): # ad_out = ad_net(features_tar, test=True).view(-1) if type == 1: entropy = losssntg.Entropy(softmax_out_target) conf = torch.exp(-entropy) th = (0.3 * conf.median() + 0.7 * conf.max()) # th = conf.mean() elif type == 2: conf = softmax_out_target.max(dim=1)[0] th = (0.3 * conf.median() + 0.7 * conf.max()) # th = conf.mean() # w = torch.where(conf<th, torch.full_like(w, 0),w) sel = torch.where(conf < th, torch.full_like(conf, 0), torch.full_like(conf, 1)) w = ((conf / (1 - th)) * sel).cuda().detach() return w, sel.sum()
def train(config): ## set pre-process prep_dict = {} prep_config = config["prep"] prep_dict["source"] = prep.image_train(**config["prep"]['params']) prep_dict["target"] = prep.image_train(**config["prep"]['params']) if prep_config["test_10crop"]: prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params']) else: prep_dict["test"] = prep.image_test(**config["prep"]['params']) ## prepare data dsets = {} dset_loaders = {} data_config = config["data"] train_bs = data_config["source"]["batch_size"] test_bs = data_config["test"]["batch_size"] dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \ transform=prep_dict["source"]) dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \ shuffle=True, num_workers=0, drop_last=True) dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \ transform=prep_dict["target"]) dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \ shuffle=True, num_workers=0, drop_last=True) if prep_config["test_10crop"]: for i in range(10): dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \ transform=prep_dict["test"][i]) for i in range(10)] dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \ shuffle=False, num_workers=0) for dset in dsets['test']] else: dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \ transform=prep_dict["test"]) dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \ shuffle=False, num_workers=0) # class_num = config["network"]["params"]["class_num"] n_class = config["network"]["params"]["class_num"] ## set base network net_config = config["network"] base_network_stu = net_config["name"](**net_config["params"]).cuda() base_network_tea = net_config["name"](**net_config["params"]).cuda() ## add additional network for some methods if config["loss"]["random"]: random_layer = network.RandomLayer([base_network_stu.output_num(), n_class], config["loss"]["random_dim"]) ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024) else: random_layer = None if config['method'] == 'DANN': ad_net = network.AdversarialNetwork(base_network_stu.output_num(), 1024)#DANN else: ad_net = network.AdversarialNetwork(base_network_stu.output_num() * n_class, 1024) if config["loss"]["random"]: random_layer.cuda() ad_net = ad_net.cuda() ad_net2 = network.AdversarialNetwork(n_class, n_class*4) ad_net2.cuda() parameter_list = base_network_stu.get_parameters() + ad_net.get_parameters() teacher_params = list(base_network_tea.parameters()) for param in teacher_params: param.requires_grad = False ## set optimizer optimizer_config = config["optimizer"] optimizer = optimizer_config["type"](parameter_list, \ **(optimizer_config["optim_params"])) teacher_optimizer = EMAWeightOptimizer(base_network_tea, base_network_stu, alpha=0.99) param_lr = [] for param_group in optimizer.param_groups: param_lr.append(param_group["lr"]) schedule_param = optimizer_config["lr_param"] lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]] ## train len_train_source = len(dset_loaders["source"]) len_train_target = len(dset_loaders["target"]) best_acc = 0.0 output1(log_name) loss1, loss2, loss3, loss4, loss5,loss6 = 0, 0, 0, 0, 0,0 output1(' ======= DA TRAINING ======= ') best1 = 0 f_t_result = [] max_iter = config["num_iterations"] for i in range(max_iter+1): if i % config["test_interval"] == config["test_interval"] - 1 and i > 1500: base_network_tea.train(False) base_network_stu.train(False) # print("test") if 'MT' in config['method']: temp_acc = image_classification_test(dset_loaders, base_network_tea, test_10crop=prep_config["test_10crop"]) if temp_acc > best_acc: best_acc = temp_acc log_str = "iter: {:05d}, tea_precision: {:.5f}".format(i, temp_acc) output1(log_str) # # if i > 20001 and best_acc < 0.69: # break # # if temp_acc < 0.67: # break # # if i > 30001 and best_acc < 0.71: # break # torch.save(base_network_tea, osp.join(path, "_model.pth.tar")) # temp_acc = image_classification_test(dset_loaders, # base_network_stu, test_10crop=prep_config["test_10crop"]) if temp_acc > best_acc: best_acc = temp_acc torch.save(base_network_stu, osp.join(path, "_model.pth.tar")) # log_str = "iter: {:05d}, stu_precision: {:.5f}".format(i, temp_acc) # output1(log_str) else: temp_acc = image_classification_test(dset_loaders, base_network_stu, test_10crop=prep_config["test_10crop"]) if temp_acc > best_acc: best_acc = temp_acc torch.save(base_network_stu, osp.join(path,"_model.pth.tar")) log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc) output1(log_str) loss_params = config["loss"] ## train one iter base_network_stu.train(True) base_network_tea.train(True) ad_net.train(True) optimizer = lr_scheduler(optimizer, i, **schedule_param) optimizer.zero_grad() if i % len_train_target == 0: iter_target = iter(dset_loaders["target"]) if i % len_train_source == 0: iter_source = iter(dset_loaders["source"]) inputs_source, labels_source = iter_source.next() inputs_target, labels_target = iter_target.next() inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda() features_source, outputs_source = base_network_stu(inputs_source) features_target_stu, outputs_target_stu = base_network_stu(inputs_target) features_target_tea, outputs_target_tea = base_network_tea(inputs_target) softmax_out_source = nn.Softmax(dim=1)(outputs_source) softmax_out_target_stu = nn.Softmax(dim=1)(outputs_target_stu) softmax_out_target_tea = nn.Softmax(dim=1)(outputs_target_tea) features = torch.cat((features_source, features_target_stu), dim=0) if 'MT' in config['method']: softmax_out = torch.cat((softmax_out_source, softmax_out_target_tea), dim=0) else: softmax_out = torch.cat((softmax_out_source, softmax_out_target_stu), dim=0) vat_loss = VAT(base_network_stu).cuda() n, d = features_source.shape decay = cal_decay(start=1,end=0.6,i = i) # image number in each class s_labels = labels_source t_max, t_labels = torch.max(softmax_out_target_tea, 1) t_max, t_labels = t_max.cuda(), t_labels.cuda() if config['method'] == 'DANN+dis' or config['method'] == 'CDRL': pass elif config['method'] == 'RESNET': classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) total_loss = classifier_loss elif config['method'] == 'CDAN+E': entropy = losssntg.Entropy(softmax_out) ad_loss = losssntg.CDANori([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer) transfer_loss = loss_params["trade_off"] * ad_loss classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) total_loss = transfer_loss + classifier_loss elif config['method'] == 'CDAN': ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer) transfer_loss = loss_params["trade_off"] * ad_loss classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) total_loss = transfer_loss + classifier_loss elif config['method'] == 'DANN': ad_loss = losssntg.DANN(features, ad_net) transfer_loss = loss_params["trade_off"] * ad_loss classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) total_loss = transfer_loss + classifier_loss elif config['method'] == 'CDAN+MT': th = config['th'] ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer) unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th) unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th) transfer_loss = loss_params["trade_off"] * ad_loss \ + 0.01 * unsup_loss classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) total_loss = transfer_loss + classifier_loss elif config['method'] == 'CDAN+MT+VAT': cent = ConditionalEntropyLoss().cuda() ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer) unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class) loss_trg_cent = 1e-2 * cent(outputs_target_stu) loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu) transfer_loss = loss_params["trade_off"] * ad_loss \ + 0.001*(unsup_loss + loss_trg_cent + loss_trg_vat) classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) total_loss = transfer_loss + classifier_loss elif config['method'] == 'CDAN+MT+cent+VAT+temp': th = 0.7 ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer) unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th) cent = ConditionalEntropyLoss().cuda() # loss_src_vat = vat_loss(inputs_source, outputs_source) loss_trg_cent = 1e-2 * cent(outputs_target_stu) loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu) transfer_loss = loss_params["trade_off"] * ad_loss \ + unsup_loss + loss_trg_cent + loss_trg_vat # temperature classifier_loss = nn.NLLLoss()(nn.LogSoftmax(1)(outputs_source / 1.05), labels_source) total_loss = transfer_loss + classifier_loss elif config['method'] == 'CDAN+MT+cent+VAT+weightCross+T': if i % len_train_target == 0: if i != 0: # print(cnt) cnt = torch.tensor(cnt).float() weight = cnt.sum() - cnt weight = weight.cuda() else: weight = torch.ones(n_class).cuda() cnt = [0] * n_class for j in t_labels: cnt[j.item()] += 1 a = config['a'] b = config['b'] th = config['th'] temp = config['temp'] ad_loss = losssntg.CDANori([features, softmax_out], ad_net, None, None, random_layer) unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th) # cbloss = compute_cbloss(softmax_out_target_stu, n_class, cls_balance=0.05) # unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th) cent = ConditionalEntropyLoss().cuda() # loss_src_vat = vat_loss(inputs_source, outputs_source) loss_trg_cent = 1e-2 * cent(outputs_target_stu) loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu) classifier_loss = nn.CrossEntropyLoss(weight=weight)(outputs_source/temp, labels_source) # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) transfer_loss = loss_params["trade_off"] * ad_loss \ + a*unsup_loss + b*(loss_trg_vat+loss_trg_cent) total_loss = transfer_loss + classifier_loss elif config['method'] == 'CDAN+MT+E+VAT+weightCross+T': entropy = losssntg.Entropy(softmax_out) if i % len_train_target == 0: if i != 0: # print(cnt) cnt = torch.tensor(cnt).float() weight = cnt.sum() - cnt weight = weight.cuda() else: weight = torch.ones(n_class).cuda() cnt = [0] * n_class for j in t_labels: cnt[j.item()] += 1 a = config['a'] b = config['b'] th = config['th'] temp = config['temp'] ad_loss = losssntg.CDANori([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer) # unsup_loss = compute_aug_loss2(softmax_out_target_stu, softmax_out_target_tea, n_class,confidence_thresh=th) # cbloss = compute_cbloss(softmax_out_target_stu, n_class, cls_balance=0.05) unsup_loss = compute_aug_loss(softmax_out_target_stu, softmax_out_target_tea, n_class, confidence_thresh=th) cent = ConditionalEntropyLoss().cuda() # loss_src_vat = vat_loss(inputs_source, outputs_source) loss_trg_cent = 1e-2 * cent(outputs_target_stu) loss_trg_vat = 1e-2 * vat_loss(inputs_target, outputs_target_stu) classifier_loss = nn.CrossEntropyLoss(weight=weight)(outputs_source / temp, labels_source) # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) transfer_loss = loss_params["trade_off"] * ad_loss \ + a * unsup_loss + b * (loss_trg_vat + loss_trg_cent) total_loss = transfer_loss + classifier_loss # adout # ad_out1 = ad_net(features_source) # w = 1-ad_out1 # c = w * nn.CrossEntropyLoss(reduction='none')(outputs_source, labels_source) # classifier_loss = c.mean() # total_loss = transfer_loss + classifier_loss total_loss.backward() optimizer.step() teacher_optimizer.step() loss1 += ad_loss.item() loss2 += classifier_loss.item() # loss3 += unsup_loss.item() # loss4 += loss_trg_cent.item() # loss5 += loss_trg_vat.item() # loss6 += cbloss.item() # dis_sloss_l += sloss_l.item() if i % 50 == 0 and i != 0: output1('iter:{:d}, ad_loss_D:{:.2f}, closs:{:.2f}, unsup_loss:{:.2f}, loss_trg_cent:{:.2f}, loss_trg_vatcd:{:.2f}, cbloss:{:.2f}' .format(i, loss1, loss2, loss3, loss4, loss5,loss6)) loss1, loss2, loss3, loss4, loss5, loss6 = 0, 0, 0, 0, 0, 0 # torch.save(best_model, osp.join(path, "best_model.pth.tar")) return best_acc