def train_pada(config):
    """Two-stage PADA-style adversarial domain adaptation.

    Stage 1 trains a source-only extractor/classifier pair on labeled source
    data. That pair is then frozen and copied into a second ("target") pair,
    which is trained adversarially (PADA/CDAN/CDAN-E) in stage 2 with an L1
    constraint tying the target extractor's source features to the frozen
    stage-1 features.

    Args:
        config: dict of settings; keys used here include 'network',
            'n_flattens', 'n_hiddens', 'n_class', 'random_layer', 'res_dir',
            'normal', 'pada_cons_w', 'lr', 'models', 'source_train_loader',
            'target_train_loader', 'source_test_loader', 'target_test_loader',
            'target_labeling', 'n_epochs', 'TEST_INTERVAL', 'VIS_INTERVAL'.
    """
    # Two independent extractor/classifier pairs: *_s is trained on source
    # only (stage 1, then frozen); *_t is the adapted model (stage 2).
    if config['network'] == 'inceptionv1':
        extractor_s = InceptionV1(num_classes=32)
        extractor_t = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor_s = InceptionV1s(num_classes=32)
        extractor_t = InceptionV1s(num_classes=32)
    else:
        extractor_s = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])
        extractor_t = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])
    classifier_s = Classifier(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              n_class=config['n_class'])
    classifier_t = Classifier(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              n_class=config['n_class'])
    if torch.cuda.is_available():
        extractor_s = extractor_s.cuda()
        classifier_s = classifier_s.cuda()
        extractor_t = extractor_t.cuda()
        classifier_t = classifier_t.cuda()

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'normal{}-{}-cons{}-lr{}'.format(config['normal'], config['network'],
                                         config['pada_cons_w'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_pada')
    print(config)
    set_log_config(res_dir)
    logging.debug('train_pada')
    logging.debug(config)

    # Domain discriminator; its input width depends on the chosen model.
    if config['models'] == 'PADA':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    ad_net = ad_net.cuda()

    optimizer_s = torch.optim.Adam([
        {'params': extractor_s.parameters(), 'lr': config['lr']},
        {'params': classifier_s.parameters(), 'lr': config['lr']},
    ])
    optimizer_t = torch.optim.Adam([
        {'params': extractor_t.parameters(), 'lr': config['lr']},
        {'params': classifier_t.parameters(), 'lr': config['lr']},
    ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    def train_stage1(extractor_s, classifier_s, config, epoch):
        """Stage 1: one epoch of supervised training of the source pair."""
        extractor_s.train()
        classifier_s.train()
        iter_source = iter(config['source_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        for step in range(1, len_source_loader + 1):
            # FIX: iterator.next() is Python-2 only; use the builtin next().
            data_source, label_source = next(iter_source)
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
            optimizer_s.zero_grad()
            h_s = extractor_s(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            source_preds = classifier_s(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            cls_loss.backward()
            optimizer_s.step()

    def train(extractor_s, classifier_s, extractor_t, classifier_t, ad_net,
              config, epoch):
        """Stage 2: one epoch of adversarial adaptation of the target pair."""
        start_epoch = 0
        extractor_t.train()
        classifier_t.train()
        ad_net.train()
        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            # Restart the (typically shorter) target loader once consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()

            optimizer_t.zero_grad()
            optimizer_ad.zero_grad()

            # Note: source batches are fed through the *target* extractor here;
            # the frozen source extractor is only used for the constraint below.
            h_s = extractor_t(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor_t(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier_t(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier_t(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)
            if config['target_labeling'] == 1:
                cls_loss += nn.CrossEntropyLoss()(target_preds, label_target)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                # Adversarial weight ramps from 0 towards 1 over training.
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) + step),
                        random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif config['models'] == 'PADA':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            # Constraint: keep the target extractor's source features close to
            # the frozen stage-1 extractor's features (L1 distance).
            h_s_prev = extractor_s(data_source)
            cons_loss = nn.L1Loss()(h_s, h_s_prev)

            loss = cls_loss + d_loss + config['pada_cons_w'] * cons_loss
            loss.backward()
            optimizer_t.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 20 == 0:
                print(
                    'Train Epoch {} closs {:.6f}, dloss {:.6f}, cons_loss {:.6f}, Loss {:.6f}'
                    .format(epoch, cls_loss.item(), d_loss.item(),
                            cons_loss.item(), loss.item()))

    # Stage 1: supervised training of the source pair.
    for epoch in range(1, config['n_epochs'] + 1):
        train_stage1(extractor_s, classifier_s, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor_s, classifier_s, config['source_test_loader'], epoch)

    # Initialize the target pair from stage 1, then freeze the source pair.
    extractor_t.load_state_dict(extractor_s.state_dict())
    classifier_t.load_state_dict(classifier_s.state_dict())
    for param in extractor_s.parameters():
        param.requires_grad = False
    for param in classifier_s.parameters():
        param.requires_grad = False

    # Stage 2: adversarial adaptation with periodic evaluation/visualization.
    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor_s, classifier_s, extractor_t, classifier_t, ad_net,
              config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            accuracy = test(extractor_t, classifier_t,
                            config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor_t, classifier_t,
                                  config['target_test_loader'], res_dir,
                                  epoch, title)
            draw_tsne(extractor_t, classifier_t,
                      config['source_train_loader'],
                      config['target_test_loader'], res_dir, epoch, title,
                      separate=True)
def train_cdan_vat(config):
    """CDAN/DANN training with a virtual-adversarial-training (VAT)
    consistency loss on the target domain.

    Args:
        config: dict of settings; keys used here include 'n_flattens',
            'n_hiddens', 'n_class', 'random_layer', 'res_dir', 'batch_size',
            'lr', 'models', 'source_train_loader', 'target_train_loader',
            'source_test_loader', 'target_test_loader', 'n_epochs',
            'TEST_INTERVAL', 'VIS_INTERVAL', 'testonly'.
    """
    extractor = Extractor(n_flattens=config['n_flattens'],
                          n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    # VAT perturbation-generator hyper-parameters.
    xi = 1e-06
    ip = 1
    eps = 15
    vat = VirtualAdversarialPerturbationGenerator(extractor, classifier,
                                                  xi=xi, eps=eps, ip=ip)

    # FIX: these two names were used in the training loop below but never
    # defined anywhere (NameError at runtime).
    # NOTE(review): MSE between adversarial and clean logits is one common
    # consistency choice -- confirm the intended criterion and weight.
    target_consistency_criterion = nn.MSELoss()
    target_vat_loss_weight = 1.0

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'random{}-bs{}-lr{}'.format(cdan_random, config['batch_size'],
                                    config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_cdan')
    print(extractor)
    print(classifier)
    print(config)
    set_log_config(res_dir)
    logging.debug('train_cdan')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    # Domain discriminator; its input width depends on the chosen model.
    if config['models'] == 'DANN':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    ad_net = ad_net.cuda()

    optimizer = torch.optim.Adam([
        {'params': extractor.parameters(), 'lr': config['lr']},
        {'params': classifier.parameters(), 'lr': config['lr']}
    ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        """One epoch of adversarial training plus the VAT consistency term."""
        start_epoch = 0
        extractor.train()
        classifier.train()
        ad_net.train()
        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            # FIX: iterator.next() is Python-2 only; use the builtin next().
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart the (typically shorter) target loader once consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                # Adversarial weight ramps from 0 towards 1 over training.
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) + step),
                        random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif config['models'] == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                elif config['models'] == 'CDAN_VAT':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            loss = cls_loss + d_loss
            loss.backward()
            optimizer.step()

            # VAT consistency: perturb target inputs adversarially and penalize
            # the divergence between perturbed and clean predictions.
            vat_adv, clean_vat_logits = vat(data_target)
            vat_adv_inputs = data_target + vat_adv
            adv_vat_features = extractor(vat_adv_inputs)
            adv_vat_logits = classifier(adv_vat_features)
            target_vat_loss = target_consistency_criterion(adv_vat_logits,
                                                           clean_vat_logits)
            vat_loss = target_vat_loss_weight * target_vat_loss
            vat_loss.backward()
            # NOTE(review): no optimizer.step() follows this backward(), and
            # zero_grad() at the next iteration discards these gradients -- as
            # written the VAT loss never updates the weights. Confirm whether
            # a second optimizer.step() was intended here.
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 20 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, vat_loss {:.6f}, Loss {:.6f}'.format(
                    epoch, cls_loss.item(), d_loss.item(), vat_loss.item(),
                    loss.item()))

    if config['testonly'] == 0:
        # Training mode: keep the checkpoint with the best target accuracy.
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, ad_net, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                print('test on source_test_loader')
                test(extractor, classifier, config['source_test_loader'], epoch)
                print('test on target_test_loader')
                accuracy = test(extractor, classifier,
                                config['target_test_loader'], epoch)
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                    torch.save(extractor.state_dict(), extractor_path)
                    torch.save(classifier.state_dict(), classifier_path)
                    torch.save(ad_net.state_dict(), adnet_path)
                print('epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.format(
                    epoch, accuracy, best_accuracy, best_model_index))
            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier,
                                      config['target_test_loader'], res_dir,
                                      epoch, title)
                draw_tsne(extractor, classifier,
                          config['source_test_loader'],
                          config['target_test_loader'], res_dir, epoch, title,
                          separate=True)
    else:
        # Test-only mode: load the saved checkpoint and evaluate/visualize.
        if os.path.exists(extractor_path) and os.path.exists(classifier_path) and os.path.exists(adnet_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            ad_net.load_state_dict(torch.load(adnet_path))
            print('Test only mode, model loaded')
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], -1)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], -1)
            title = config['models']
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, -1,
                                  title)
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, -1, title,
                      separate=True)
        else:
            print('no saved model found')
def train_wasserstein(config):
    """WDGRL training (Wasserstein distance guided representation learning)
    with an optional triplet-loss stage (TLADA) chosen by
    config['triplet_type'].

    Alternates k_critic critic updates (with gradient penalty) against k_clf
    extractor/classifier updates per step. When triplet_type != 'none', a
    triplet loss over confident pseudo-labeled target samples is added from
    epoch TRIPLET_START_INDEX onward, warm-starting from the plain-WDGRL
    checkpoints when they exist.
    """
    # NOTE(review): the backbone is hard-coded to InceptionV1(num_classes=32)
    # (the generic Extractor line was deliberately replaced) -- confirm.
    extractor = InceptionV1(num_classes=32)
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    critic = Critic(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    triplet_type = config['triplet_type']
    gamma = config['w_gamma']        # gradient-penalty weight
    weight_wd = config['w_weight']   # Wasserstein-distance weight
    weight_triplet = config['t_weight']
    t_margin = config['t_margin']
    t_confidence = config['t_confidence']
    k_critic = 3                     # critic updates per step
    k_clf = 1                        # feature/classifier updates per step
    TRIPLET_START_INDEX = 95         # epoch at which the triplet stage begins

    if triplet_type == 'none':
        res_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(res_dir, "extractor.pth")
        classifier_path = os.path.join(res_dir, "classifier.pth")
        critic_path = os.path.join(res_dir, "critic.pth")
        EPOCH_START = 1
        TEST_INTERVAL = 10
    else:
        TEST_INTERVAL = 1
        w_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(w_dir):
            os.makedirs(w_dir)
        res_dir = os.path.join(
            w_dir, '{}_t_weight{}_margin{}_confidence{}'.format(
                triplet_type, weight_triplet, t_margin, t_confidence))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        # Warm-start from the plain-WDGRL checkpoints when available.
        extractor_path = os.path.join(w_dir, "extractor.pth")
        classifier_path = os.path.join(w_dir, "classifier.pth")
        critic_path = os.path.join(w_dir, "critic.pth")
        if os.path.exists(extractor_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            critic.load_state_dict(torch.load(critic_path))
            print('load models')
            EPOCH_START = TRIPLET_START_INDEX
        else:
            EPOCH_START = 1

    set_log_config(res_dir)
    print('start epoch {}'.format(EPOCH_START))
    print('triplet type {}'.format(triplet_type))
    print(config)
    logging.debug('train_wt')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    softmax_layer = nn.Softmax(dim=1)

    # Separate optimizers; the extractor learns 10x slower than the heads.
    critic_opt = torch.optim.Adam(critic.parameters(), lr=config['lr'])
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=config['lr'])
    feature_opt = torch.optim.Adam(extractor.parameters(), lr=config['lr'] / 10)

    def train(extractor, classifier, critic, config, epoch):
        """One epoch of alternating critic / feature-classifier updates."""
        extractor.train()
        classifier.train()
        critic.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        # NOTE(review): range(1, num_iter) skips the final batch; the other
        # trainers in this file use range(1, num_iter + 1). Confirm intent.
        for step in range(1, num_iter):
            # FIX: iterator.next() is Python-2 only; use the builtin next().
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart the (typically shorter) target loader once consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()

            # 1. Train the critic on frozen features.
            set_requires_grad(extractor, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            with torch.no_grad():
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

            for j in range(k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                # WGAN-GP objective: maximize wd, penalize gradient norm.
                critic_cost = -wasserstein_distance + gamma * gp

                critic_opt.zero_grad()
                critic_cost.backward()
                critic_opt.step()

                if step == 10 and j == 0:
                    print('EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.format(
                        epoch, wasserstein_distance.item(),
                        (gamma * gp).item(), critic_cost.item()))
                    logging.debug('EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.format(
                        epoch, wasserstein_distance.item(),
                        (gamma * gp).item(), critic_cost.item()))

            # 2. Train the feature extractor and classifier against the critic.
            set_requires_grad(extractor, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(k_clf):
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

                source_preds = classifier(h_s)
                clf_loss = criterion(source_preds, label_source)
                wasserstein_distance = critic(h_s).mean() - critic(h_t).mean()

                if triplet_type != 'none' and epoch >= TRIPLET_START_INDEX:
                    # Pseudo-label confident target samples for the triplet loss.
                    target_preds = classifier(h_t)
                    target_labels = target_preds.data.max(1)[1]
                    target_logits = softmax_layer(target_preds)
                    if triplet_type == 'all':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() > t_margin)[0]
                        images = torch.cat((h_s, h_t[triplet_index]), 0)
                        labels = torch.cat(
                            (label_source, target_labels[triplet_index]), 0)
                    elif triplet_type == 'src':
                        images = h_s
                        labels = label_source
                    elif triplet_type == 'tgt':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() > t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                    elif triplet_type == 'sep':
                        # 'sep': separate triplet terms for target and source.
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() > t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                        t_loss_sep, _ = triplet_loss(extractor, {
                            "X": images,
                            "y": labels
                        }, t_confidence)
                        images = h_s
                        labels = label_source
                    t_loss, _ = triplet_loss(extractor, {
                        "X": images,
                        "y": labels
                    }, t_margin)

                    loss = clf_loss + \
                        weight_wd * wasserstein_distance + \
                        weight_triplet * t_loss
                    if triplet_type == 'sep':
                        loss += t_loss_sep
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()
                    if step == 10:
                        print('EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'.format(
                            epoch, clf_loss.item(),
                            weight_wd * wasserstein_distance.item(),
                            weight_triplet * t_loss.item(), loss.item()))
                        logging.debug('EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'.format(
                            epoch, clf_loss.item(),
                            weight_wd * wasserstein_distance.item(),
                            weight_triplet * t_loss.item(), loss.item()))
                else:
                    loss = clf_loss + weight_wd * wasserstein_distance
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()
                    if step == 10:
                        print('EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, loss {}'.format(
                            epoch, clf_loss.item(),
                            weight_wd * wasserstein_distance.item(),
                            loss.item()))
                        logging.debug('EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, loss {}'.format(
                            epoch, clf_loss.item(),
                            weight_wd * wasserstein_distance.item(),
                            loss.item()))

    for epoch in range(EPOCH_START, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % TEST_INTERVAL == 0:
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            if triplet_type == 'none':
                title = '(a) WDGRL'
            else:
                title = '(b) TLADA'
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, title)
            draw_tsne(extractor, classifier, config['source_train_loader'],
                      config['target_test_loader'], res_dir, epoch, title,
                      separate=True)
        # Checkpoint the plain-WDGRL run so the triplet stage can warm-start.
        if triplet_type == 'none':
            torch.save(extractor.state_dict(), extractor_path)
            torch.save(classifier.state_dict(), classifier_path)
            torch.save(critic.state_dict(), critic_path)
def train_cdan(config):
    """CDAN/DANN training with BNM and VAT regularizers, plus optional
    semi-supervised labeled target batches when config['slim'] > 0.

    Args:
        config: dict of settings; keys used here include 'network',
            'n_flattens', 'n_hiddens', 'n_class', 'bn', 'random_layer',
            'res_dir', 'slim', 'target_labeling', 'snr', 'snrp', 'lr',
            'models', the train/test loaders, 'n_epochs', 'TEST_INTERVAL',
            'VIS_INTERVAL', 'testonly'.
    """
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32)
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              bn=config['bn'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    # VAT wraps the live extractor/classifier modules; nn.Module.cuda() is
    # in-place, so moving them below also covers the modules VAT references.
    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5).cuda()
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-targetLabel{}-snr{}-snrp{}-lr{}'.format(
            config['slim'], config['target_labeling'], config['snr'],
            config['snrp'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_cdan')
    print(config)
    set_log_config(res_dir)
    logging.debug('train_cdan')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    # Domain discriminator; its input width depends on the chosen model.
    if config['models'] == 'DANN':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    ad_net = ad_net.cuda()

    optimizer = torch.optim.Adam([
        {'params': extractor.parameters(), 'lr': config['lr']},
        {'params': classifier.parameters(), 'lr': config['lr']}
    ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        """One epoch of adversarial training with BNM/VAT regularization."""
        start_epoch = 0
        extractor.train()
        classifier.train()
        ad_net.train()
        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            # FIX: iterator.next() is Python-2 only; use the builtin next().
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if step % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            # Restart the (typically shorter) target loader once consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()
                if config['slim'] > 0:
                    data_target_semi, label_target_semi = data_target_semi.cuda(), label_target_semi.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)
            if config['target_labeling'] == 1:
                cls_loss += nn.CrossEntropyLoss()(target_preds, label_target)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                # Adversarial weight ramps from 0 towards 1 over training.
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) + step),
                        random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif config['models'] == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            loss = cls_loss + d_loss

            # Regularizers: batch nuclear-norm maximization on target
            # predictions plus VAT smoothness on both domains.
            err_t_bnm = get_loss_bnm(target_preds)
            err_s_vat = vat_loss(data_source, source_preds)
            err_t_vat = vat_loss(data_target, target_preds)
            loss += 1.0 * err_s_vat + 1.0 * err_t_vat + 1.0 * err_t_bnm

            # Optional supervised loss on the labeled target subset.
            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi = classifier(feature_target_semi)
                loss += nn.CrossEntropyLoss()(preds_target_semi,
                                              label_target_semi)

            loss.backward()
            optimizer.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 100 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, Loss {:.6f}'.format(
                    epoch, cls_loss.item(), d_loss.item(), loss.item()))

    if config['testonly'] == 0:
        # Training mode: keep the checkpoint with the best target accuracy.
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, ad_net, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                print('test on target_test_loader')
                accuracy = test(extractor, classifier,
                                config['target_test_loader'], epoch)
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                    torch.save(extractor.state_dict(), extractor_path)
                    torch.save(classifier.state_dict(), classifier_path)
                    torch.save(ad_net.state_dict(), adnet_path)
                print('epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.format(
                    epoch, accuracy, best_accuracy, best_model_index))
            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier,
                                      config['target_test_loader'], res_dir,
                                      epoch, title)
                draw_tsne(extractor, classifier,
                          config['source_train_loader'],
                          config['target_test_loader'], res_dir, epoch, title,
                          separate=True)
    else:
        # Test-only mode: load the saved checkpoint and evaluate/visualize.
        if os.path.exists(extractor_path) and os.path.exists(classifier_path) and os.path.exists(adnet_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            ad_net.load_state_dict(torch.load(adnet_path))
            print('Test only mode, model loaded')
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], -1)
            title = config['models']
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, -1,
                                  title)
            draw_tsne(extractor, classifier, config['source_train_loader'],
                      config['target_test_loader'], res_dir, -1, title,
                      separate=True)
        else:
            print('no saved model found')