def train_cnn(config):
    # Source-only baseline: train extractor + classifier on the source domain
    # with plain cross-entropy, no adaptation.
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-snr{}-snrp{}-lr{}'.format(config['slim'], config['snr'],
                                          config['snrp'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(extractor.parameters()) +
                           list(classifier.parameters()),
                           lr=config['lr'])

    def train(extractor, classifier, config, epoch):
        extractor.train()
        classifier.train()
        for step, (features, labels) in enumerate(config['source_train_loader']):
            if torch.cuda.is_available():
                features, labels = features.cuda(), labels.cuda()
            optimizer.zero_grad()
            # Alternative using auxiliary classifier heads:
            # if config['aux_classifier'] == 1:
            #     x1, x2, x3 = extractor(features)
            #     preds = classifier(x1, x2, x3)
            preds, _ = classifier(extractor(features))
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, config['models'])
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, epoch,
                      config['models'], separate=True)
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, epoch,
                      config['models'], separate=False)
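# ---------------------------------------------------------------------------
# Every trainer in this file calls a shared test() helper that is defined
# elsewhere in the repo.  The following is a minimal sketch of such a helper,
# assuming (as in train_cnn above) that the classifier returns a
# (logits, features) tuple; the repo's actual implementation may log more
# detail or compute per-class metrics.
# ---------------------------------------------------------------------------
def test(extractor, classifier, data_loader, epoch):
    extractor.eval()
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for features, labels in data_loader:
            if torch.cuda.is_available():
                features, labels = features.cuda(), labels.cuda()
            feature = extractor(features)
            feature = feature.view(feature.size(0), -1)
            preds, _ = classifier(feature)
            correct += (preds.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    accuracy = 100.0 * correct / total
    print('epoch {}, accuracy {:.2f}%'.format(epoch, accuracy))
    logging.debug('epoch {}, accuracy {:.2f}%'.format(epoch, accuracy))
    return accuracy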
def train_dann(config):
    # DANN with optional MMD / BNM / entropy-minimization / VAT regularizers
    # on the target predictions, plus an optional supervised loss on a small
    # labeled target subset (config['slim'] > 0).
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(
        config['res_dir'],
        'VIS-slim{}-targetLabel{}-mmd{}-bnm{}-vat{}-ent{}-ew{}-bn{}-bs{}-lr{}'.format(
            config['slim'], config['target_labeling'], config['mmd'],
            config['bnm'], config['vat'], config['ent'], config['bnm_ew'],
            config['bn'], config['batch_size'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug('train_dann')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{'params': extractor.parameters()},
                            {'params': classifier.parameters()},
                            {'params': critic.parameters()}],
                           lr=config['lr'])

    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5).cuda()

    def dann(input_data, alpha):
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output, _ = classifier(feature)
        domain_output = critic(reverse_feature)
        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        # GRL coefficient schedule from the DANN paper: ramps from 0 to 1.
        gamma = 2 / (1 + math.exp(-10 * epoch / config['n_epochs'])) - 1
        mmd_loss = MMD_loss()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            # Source pass: classification loss + domain loss (domain label 0).
            class_output_s, domain_output, feature_s = dann(
                input_data=data_source, alpha=gamma)
            err_s_label = loss_class(class_output_s, label_source)
            domain_label = torch.zeros(data_source.size(0)).long().cuda()
            err_s_domain = loss_domain(domain_output, domain_label)

            # Target pass: domain loss only (domain label 1).
            domain_label = torch.ones(data_target.size(0)).long().cuda()
            class_output_t, domain_output, feature_t = dann(
                input_data=data_target, alpha=gamma)
            err_t_domain = loss_domain(domain_output, domain_label)

            err = err_s_label + err_s_domain + err_t_domain
            # if config['target_labeling'] == 1:
            #     err += nn.CrossEntropyLoss()(class_output_t, label_target)

            if config['mmd'] == 1:
                # err += gamma * mmd_linear(feature_s, feature_t)
                err += config['bnm_ew'] * mmd_loss(feature_s, feature_t)

            if config['bnm'] == 1 and epoch >= config['startiter']:
                err_t_bnm = config['bnm_ew'] * get_loss_bnm(class_output_t)
                err += err_t_bnm
                if i == 1:
                    print('epoch {}, loss_t_bnm {:.2f}'.format(
                        epoch, err_t_bnm.item()))

            if config['ent'] == 1 and epoch >= config['startiter']:
                err_t_ent = config['bnm_ew'] * get_loss_entropy(class_output_t)
                err += err_t_ent
                if i == 1:
                    print('epoch {}, loss_t_ent {:.2f}'.format(
                        epoch, err_t_ent.item()))

            if config['vat'] == 1 and epoch >= config['startiter']:
                err_t_vat = config['bnm_ew'] * vat_loss(data_target,
                                                        class_output_t)
                err += err_t_vat
                if i == 1:
                    print('epoch {}, loss_t_vat {:.2f}'.format(
                        epoch, err_t_vat.item()))

            if config['slim'] > 0:
                # Supervised loss on the small labeled target subset.
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi, _ = classifier(feature_target_semi)
                err_t_class_semi = loss_class(preds_target_semi,
                                              label_target_semi)
                err += err_t_class_semi
                if i == 1:
                    print('epoch {}, err_t_class_semi {:.2f}'.format(
                        epoch, err_t_class_semi.item()))

            if i == 1:
                print('epoch {}, err_s_label {:.2f}, err_s_domain {:.2f}, '
                      'err_t_domain {:.2f}, total err {:.2f}'.format(
                          epoch, err_s_label.item(), err_s_domain.item(),
                          err_t_domain.item(), err.item()))

            err.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            title = 'DANN'
            if config['bnm'] == 1 and config['vat'] == 1:
                title = '(b) Proposed'
            elif config['bnm'] == 1:
                title = 'BNM'
            elif config['vat'] == 1:
                title = 'VADA'
            elif config['mmd'] == 1:
                title = 'DCTLN'
            elif config['ent'] == 1:
                title = 'EntMin'
            # draw_confusion_matrix(extractor, classifier,
            #                       config['target_test_loader'], res_dir,
            #                       epoch, config['models'])
            draw_tsne(extractor, classifier, config['source_train_loader'],
                      config['target_test_loader'], res_dir, epoch, title,
                      separate=True)
            draw_tsne(extractor, classifier, config['source_train_loader'],
                      config['target_test_loader'], res_dir, epoch, title,
                      separate=False)
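# ---------------------------------------------------------------------------
# train_dann relies on three helpers defined elsewhere in the repo:
# ReverseLayerF (the gradient reversal layer from Ganin & Lempitsky's DANN),
# get_loss_bnm (Batch Nuclear-norm Maximization, Cui et al. 2020) and
# get_loss_entropy (entropy minimization).  The sketches below show standard
# implementations under those assumptions; the repo's versions may differ in
# detail.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Function

class ReverseLayerF(Function):
    """Identity in the forward pass; scales gradients by -alpha backward,
    so the extractor learns domain-invariant features against the critic."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the extractor.
        return grad_output.neg() * ctx.alpha, None

def get_loss_bnm(logits):
    """BNM: maximize the nuclear norm of the batch prediction matrix,
    encouraging predictions that are both confident and diverse."""
    probs = F.softmax(logits, dim=1)
    # Negative sign because the trainers *add* this term to a minimized loss.
    return -torch.norm(probs, 'nuc') / probs.size(0)

def get_loss_entropy(logits):
    """Mean Shannon entropy of the batch predictions (EntMin)."""
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    return -(probs * log_probs).sum(dim=1).mean()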
def train_dann_vat(config):
    # DANN + VAT + an entropy-style regularizer on target predictions,
    # following the VADA recipe, with per-term loss weights.
    extractor = Extractor(n_flattens=config['n_flattens'],
                          n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    # NOTE (assumption): these loss weights were referenced but never defined
    # in the original snippet; they are assumed to be provided in config.
    dw = config['dw']  # domain-adversarial weight
    cw = config['cw']  # source classification weight
    sw = config['sw']  # source VAT weight
    tw = config['tw']  # target VAT weight
    ew = config['ew']  # target entropy/BNM weight

    res_dir = os.path.join(
        config['res_dir'],
        'BNM-slim{}-targetLabel{}-sw{}-tw{}-ew{}-lr{}'.format(
            config['slim'], config['target_labeling'], sw, tw, ew,
            config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()
    cent = ConditionalEntropyLoss().cuda()
    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5).cuda()

    print('train_dann_vat')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_dann_vat')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{'params': extractor.parameters()},
                            {'params': classifier.parameters()},
                            {'params': critic.parameters()}],
                           lr=config['lr'])

    def dann(input_data, alpha):
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output, _ = classifier(feature)
        domain_output = critic(reverse_feature)
        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        gamma = 2 / (1 + math.exp(-10 * epoch / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            class_output_s, domain_output_s, _ = dann(input_data=data_source,
                                                      alpha=gamma)
            err_s_class = loss_class(class_output_s, label_source)
            domain_label_s = torch.zeros(data_source.size(0)).long().cuda()
            err_s_domain = loss_domain(domain_output_s, domain_label_s)

            # Target pass (domain label 1).
            class_output_t, domain_output_t, _ = dann(input_data=data_target,
                                                      alpha=gamma)
            domain_label_t = torch.ones(data_target.size(0)).long().cuda()
            err_t_domain = loss_domain(domain_output_t, domain_label_t)

            # Target "entropy" term: despite the variable name, the BNM loss
            # is used here; the plain entropy version is kept commented out.
            # err_t_entropy = get_loss_entropy(class_output_t)
            err_t_entropy = get_loss_bnm(class_output_t)

            # Virtual adversarial losses on both domains.
            err_s_vat = vat_loss(data_source, class_output_s)
            err_t_vat = vat_loss(data_target, class_output_t)

            err_domain = err_s_domain + err_t_domain
            err_all = (dw * err_domain +
                       cw * err_s_class +
                       sw * err_s_vat +
                       tw * err_t_vat +
                       ew * err_t_entropy)

            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi, _ = classifier(feature_target_semi)
                err_t_class_semi = loss_class(preds_target_semi,
                                              label_target_semi)
                err_all += err_t_class_semi

            # if config['target_labeling'] == 1:
            #     err_all += nn.CrossEntropyLoss()(class_output_t, label_target)

            err_all.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        print('epoch {}'.format(epoch))
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, config['models'])
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, epoch,
                      config['models'], separate=True)
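# ---------------------------------------------------------------------------
# VAT is instantiated above as VAT(extractor, classifier, n_power=1,
# radius=3.5) and called as vat_loss(x, logits), but its definition lives
# elsewhere in the repo.  Below is a minimal sketch of virtual adversarial
# training (Miyato et al., 2018) under those assumptions, again assuming the
# classifier returns a (logits, features) tuple; the repo's version may
# differ in normalization details.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAT(nn.Module):
    def __init__(self, extractor, classifier, n_power=1, radius=3.5):
        super(VAT, self).__init__()
        self.extractor = extractor
        self.classifier = classifier
        self.n_power = n_power  # power-iteration steps
        self.radius = radius    # perturbation magnitude epsilon

    def _logits(self, x):
        feature = self.extractor(x)
        feature = feature.view(feature.size(0), -1)
        logits, _ = self.classifier(feature)
        return logits

    @staticmethod
    def _kl(p_logit, q_logit):
        # KL(p || q), averaged over the batch.
        p = F.softmax(p_logit, dim=1)
        return (p * (F.log_softmax(p_logit, dim=1) -
                     F.log_softmax(q_logit, dim=1))).sum(dim=1).mean()

    @staticmethod
    def _l2_normalize(d):
        d_flat = d.view(d.size(0), -1)
        d_flat = d_flat / (d_flat.norm(dim=1, keepdim=True) + 1e-8)
        return d_flat.view_as(d)

    def forward(self, x, logits):
        # Power iteration: find the direction d that most changes predictions.
        d = torch.randn_like(x)
        for _ in range(self.n_power):
            d = (1e-6 * self._l2_normalize(d)).requires_grad_()
            dist = self._kl(logits.detach(), self._logits(x + d))
            d = torch.autograd.grad(dist, d)[0].detach()
        r_adv = self.radius * self._l2_normalize(d)
        # Penalize prediction change under the adversarial perturbation.
        return self._kl(logits.detach(), self._logits(x + r_adv))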
def train_ddc(config):
    # DDC: source cross-entropy plus a linear MMD penalty that aligns the
    # source and target feature distributions.
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-targetLabel{}-normal{}-{}-dilation{}-lr{}-mmdgamma{}'.format(
            config['slim'], config['target_labeling'], config['normal'],
            config['network'], config['dilation'], config['lr'],
            config['mmd_gamma']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    criterion = torch.nn.CrossEntropyLoss()

    set_log_config(res_dir)
    logging.debug('train_ddc')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    optimizer = optim.Adam(list(extractor.parameters()) +
                           list(classifier.parameters()),
                           lr=config['lr'])

    def train(extractor, classifier, config, epoch):
        extractor.train()
        classifier.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            source = extractor(data_source)
            source = source.view(source.size(0), -1)
            target = extractor(data_target)
            target = target.view(target.size(0), -1)

            preds, _ = classifier(source)
            loss_cls = criterion(preds, label_source)
            loss_mmd = mmd_linear(source, target)
            # gamma = 2 / (1 + math.exp(-10 * epoch / config['n_epochs'])) - 1
            loss = loss_cls + config['mmd_gamma'] * loss_mmd

            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi, _ = classifier(feature_target_semi)
                loss += criterion(preds_target_semi, label_target_semi)

            if i % 50 == 0:
                print('loss_cls {}, loss_mmd {}, gamma {}, total loss {}'.format(
                    loss_cls.item(), loss_mmd.item(), config['mmd_gamma'],
                    loss.item()))

            loss.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, config['models'])
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, epoch,
                      config['models'], separate=True)
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, epoch,
                      config['models'], separate=False)
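# ---------------------------------------------------------------------------
# train_ddc minimizes mmd_linear between batch feature matrices, and
# train_dann can optionally use MMD_loss; both helpers are defined elsewhere
# in the repo.  The sketches below show common formulations under that
# assumption: a linear-time MMD estimate and a multi-kernel Gaussian MMD.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn

def mmd_linear(f_s, f_t):
    # Squared distance between batch means in feature space: a cheap linear
    # MMD surrogate (features from both domains must share dimensionality).
    delta = f_s.mean(dim=0) - f_t.mean(dim=0)
    return (delta * delta).sum()

class MMD_loss(nn.Module):
    # Multi-kernel Gaussian MMD with a mean-distance bandwidth heuristic.
    def __init__(self, kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_mul = kernel_mul
        self.kernel_num = kernel_num

    def _kernels(self, source, target):
        total = torch.cat([source, target], dim=0)
        n = total.size(0)
        # Pairwise squared L2 distances, shape (n, n).
        l2 = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(dim=2)
        bandwidth = l2.detach().sum() / (n * n - n)
        bandwidth /= self.kernel_mul ** (self.kernel_num // 2)
        return sum(torch.exp(-l2 / (bandwidth * self.kernel_mul ** k))
                   for k in range(self.kernel_num))

    def forward(self, source, target):
        b = source.size(0)
        k = self._kernels(source, target)
        # MMD^2 = E[k(s,s)] + E[k(t,t)] - 2 E[k(s,t)]
        return (k[:b, :b].mean() + k[b:, b:].mean()
                - k[:b, b:].mean() - k[b:, :b].mean())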
def train_wasserstein(config):
    # WDGRL-style training: a critic estimates the Wasserstein distance
    # between source and target features (stabilized with a gradient
    # penalty), while the extractor/classifier minimize classification loss
    # plus that distance.  Optionally adds a triplet loss on confident
    # samples (the TLADA variants).
    extractor = Extractor(n_flattens=config['n_flattens'],
                          n_hiddens=config['n_hiddens'])
    # extractor = InceptionV1(num_classes=32)
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    critic = Critic(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    triplet_type = config['triplet_type']
    gamma = config['w_gamma']        # gradient-penalty weight
    weight_wd = config['w_weight']   # Wasserstein-distance weight
    weight_triplet = config['t_weight']
    t_margin = config['t_margin']
    t_confidence = config['t_confidence']
    k_critic = 3  # critic updates per batch
    k_clf = 1     # extractor/classifier updates per batch
    TRIPLET_START_INDEX = 95

    if triplet_type == 'none':
        res_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'],
                                           config['lr'], weight_wd, gamma))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(res_dir, 'extractor.pth')
        classifier_path = os.path.join(res_dir, 'classifier.pth')
        critic_path = os.path.join(res_dir, 'critic.pth')
        EPOCH_START = 1
        TEST_INTERVAL = 10
    else:
        TEST_INTERVAL = 1
        w_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'],
                                           config['lr'], weight_wd, gamma))
        if not os.path.exists(w_dir):
            os.makedirs(w_dir)
        res_dir = os.path.join(
            w_dir,
            '{}_t_weight{}_margin{}_confidence{}'.format(
                triplet_type, weight_triplet, t_margin, t_confidence))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(w_dir, 'extractor.pth')
        classifier_path = os.path.join(w_dir, 'classifier.pth')
        critic_path = os.path.join(w_dir, 'critic.pth')
        # Warm-start from a pretrained WDGRL checkpoint if one exists.
        if os.path.exists(extractor_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            critic.load_state_dict(torch.load(critic_path))
            print('load models')
            EPOCH_START = TRIPLET_START_INDEX
        else:
            EPOCH_START = 1

    set_log_config(res_dir)
    print('start epoch {}'.format(EPOCH_START))
    print('triplet type {}'.format(triplet_type))
    print(config)
    logging.debug('train_wt')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    softmax_layer = nn.Softmax(dim=1)

    critic_opt = torch.optim.Adam(critic.parameters(), lr=config['lr'])
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=config['lr'])
    feature_opt = torch.optim.Adam(extractor.parameters(), lr=config['lr'] / 10)

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()

            # 1. Train the critic on frozen features.
            set_requires_grad(extractor, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)

            with torch.no_grad():
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

            for j in range(k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                critic_cost = -wasserstein_distance + gamma * gp

                critic_opt.zero_grad()
                critic_cost.backward()
                critic_opt.step()

                if step == 10 and j == 0:
                    msg = 'EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.format(
                        epoch, wasserstein_distance.item(),
                        (gamma * gp).item(), critic_cost.item())
                    print(msg)
                    logging.debug(msg)

            # 2. Train the extractor and classifier against the frozen critic.
            set_requires_grad(extractor, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(k_clf):
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

                source_preds, _ = classifier(h_s)
                clf_loss = criterion(source_preds, label_source)
                wasserstein_distance = critic(h_s).mean() - critic(h_t).mean()

                if triplet_type != 'none' and epoch >= TRIPLET_START_INDEX:
                    # Build the triplet pool from source samples and/or
                    # confidently pseudo-labeled target samples.
                    target_preds, _ = classifier(h_t)
                    target_labels = target_preds.data.max(1)[1]
                    target_logits = softmax_layer(target_preds)

                    if triplet_type == 'all':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() > t_margin)[0]
                        images = torch.cat((h_s, h_t[triplet_index]), 0)
                        labels = torch.cat(
                            (label_source, target_labels[triplet_index]), 0)
                    elif triplet_type == 'src':
                        images = h_s
                        labels = label_source
                    elif triplet_type == 'tgt':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() > t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                    elif triplet_type == 'sep':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() > t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                        t_loss_sep, _ = triplet_loss(
                            extractor, {'X': images, 'y': labels}, t_confidence)
                        images = h_s
                        labels = label_source

                    t_loss, _ = triplet_loss(extractor,
                                             {'X': images, 'y': labels},
                                             t_margin)
                    loss = clf_loss + \
                        weight_wd * wasserstein_distance + \
                        weight_triplet * t_loss
                    if triplet_type == 'sep':
                        loss += t_loss_sep

                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        msg = ('EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, '
                               't_loss {}, total loss {}').format(
                                   epoch, clf_loss.item(),
                                   weight_wd * wasserstein_distance.item(),
                                   weight_triplet * t_loss.item(),
                                   loss.item())
                        print(msg)
                        logging.debug(msg)
                else:
                    loss = clf_loss + weight_wd * wasserstein_distance
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        msg = 'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, loss {}'.format(
                            epoch, clf_loss.item(),
                            weight_wd * wasserstein_distance.item(),
                            loss.item())
                        print(msg)
                        logging.debug(msg)

    # pretrain(model, config, pretrain_epochs=20)
    for epoch in range(EPOCH_START, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % TEST_INTERVAL == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            if triplet_type == 'none':
                title = '(a) WDGRL'
            else:
                title = '(b) TLADA'
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, title)
            draw_tsne(extractor, classifier, config['source_train_loader'],
                      config['target_test_loader'], res_dir, epoch, title,
                      separate=True)
            # draw_tsne(extractor, classifier, config['source_test_loader'],
            #           config['target_test_loader'], res_dir, epoch, title,
            #           separate=False)
        if triplet_type == 'none':
            torch.save(extractor.state_dict(), extractor_path)
            torch.save(classifier.state_dict(), classifier_path)
            torch.save(critic.state_dict(), critic_path)
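# ---------------------------------------------------------------------------
# train_wasserstein calls set_requires_grad() and gradient_penalty(), both
# defined elsewhere in the repo.  Sketches of the usual WDGRL / WGAN-GP
# formulations follow; the interpolation details are an assumption.
# ---------------------------------------------------------------------------
import torch

def set_requires_grad(model, requires_grad=True):
    # Freeze/unfreeze a module so only the intended sub-network is updated.
    for param in model.parameters():
        param.requires_grad = requires_grad

def gradient_penalty(critic, h_s, h_t):
    # WGAN-GP penalty on random interpolates between source/target features,
    # pushing the critic's gradient norm toward 1 (1-Lipschitz constraint).
    batch = min(h_s.size(0), h_t.size(0))
    alpha = torch.rand(batch, 1, device=h_s.device)
    interpolates = alpha * h_s[:batch] + (1 - alpha) * h_t[:batch]
    interpolates = interpolates.detach().requires_grad_()
    critic_out = critic(interpolates)
    grads = torch.autograd.grad(outputs=critic_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(critic_out),
                                create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()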
def train_dctln(config):
    # DCTLN: DANN-style domain-adversarial training; the MMD term from the
    # original DCTLN formulation is kept commented out.  Optionally uses a
    # small labeled target subset (config['slim'] > 0).
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-snr{}-lr{}'.format(config['slim'], config['snr'],
                                   config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug('train_dctln')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{'params': extractor.parameters()},
                            {'params': classifier.parameters()},
                            {'params': critic.parameters()}],
                           lr=config['lr'])

    def dann(input_data, alpha):
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output, _ = classifier(feature)
        domain_output = critic(reverse_feature)
        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        gamma = 2 / (1 + math.exp(-10 * epoch / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            source = extractor(data_source)
            source = source.view(source.size(0), -1)
            target = extractor(data_target)
            target = target.view(target.size(0), -1)
            # loss_mmd = mmd_linear(source, target)

            # Source pass: classification + domain loss (domain label 0).
            class_output_s, domain_output, _ = dann(input_data=data_source,
                                                    alpha=gamma)
            err_s_label = loss_class(class_output_s, label_source)
            domain_label = torch.zeros(data_source.size(0)).long().cuda()
            err_s_domain = loss_domain(domain_output, domain_label)

            # Target pass: domain loss only (domain label 1).
            domain_label = torch.ones(data_target.size(0)).long().cuda()
            class_output_t, domain_output, _ = dann(input_data=data_target,
                                                    alpha=gamma)
            err_t_domain = loss_domain(domain_output, domain_label)

            # err = 1.0 * err_s_label + err_s_domain + err_t_domain + 0 * loss_mmd + err_t_label
            err = err_s_label + err_s_domain + err_t_domain
            if config['slim'] > 0:
                class_output_semi_t, _, _ = dann(input_data=data_target_semi,
                                                 alpha=gamma)
                err_t_label = loss_class(class_output_semi_t,
                                         label_target_semi)
                err += err_t_label

            err.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, config['models'])
            draw_tsne(extractor, classifier, config['source_train_loader'],
                      config['target_test_loader'], res_dir, epoch,
                      config['models'], separate=True)
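# ---------------------------------------------------------------------------
# Example usage (assumption): the keys below mirror the config entries the
# trainers in this file actually read; the example values are illustrative
# only, and the data loaders and visualization helpers must be supplied by
# the rest of the repo before any trainer can run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = {
        'network': 'cnn', 'n_flattens': 256, 'n_hiddens': 500, 'n_class': 10,
        'dilation': 1, 'lr': 1e-3, 'n_epochs': 200, 'batch_size': 64,
        'TEST_INTERVAL': 10, 'VIS_INTERVAL': 50,
        'res_dir': './results', 'models': 'dann',
        'slim': 0, 'snr': 0, 'snrp': 0, 'normal': 1,
        'target_labeling': 0, 'mmd': 0, 'bnm': 1, 'vat': 1, 'ent': 0,
        'bnm_ew': 1.0, 'bn': 0, 'startiter': 50, 'mmd_gamma': 1.0,
        # Required DataLoader objects, constructed elsewhere in the repo:
        # 'source_train_loader': ..., 'source_test_loader': ...,
        # 'target_train_loader': ..., 'target_test_loader': ...,
        # 'target_train_semi_loader': ...,  # only needed when slim > 0
    }
    train_dann(config)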