def __init__(self, args):
    super().__init__()
    self.opt_losses = args.opt_losses
    self.check_list = [
        'disc_ds_di_C_D', 'disc_ds_di_G', 'ring', 'confusion_G'
    ]
    assert all([ol in self.check_list for ol in self.opt_losses]), \
        'Check loss entries'
    opt_losses_str = ",".join(map(str, self.opt_losses))
    timestring = strftime("%Y-%m-%d_%H-%M-%S", gmtime()) + \
        "_{}_optloss={}_src={}".format(args.exp_name, opt_losses_str, args.source)
    self.logdir = os.path.join('./logs', timestring)
    self.logger = SummaryWriter(log_dir=self.logdir)
    self.device = torch.device("cuda" if args.use_cuda else "cpu")

    # One-hot domain codes: source = [1, 0], target = [0, 1].
    self.src_domain_code = np.repeat(np.array([[1, 0]]), args.batch_size, axis=0)
    self.trg_domain_code = np.repeat(np.array([[0, 1]]), args.batch_size, axis=0)
    self.src_domain_code = torch.FloatTensor(self.src_domain_code).to(self.device)
    self.trg_domain_code = torch.FloatTensor(self.trg_domain_code).to(self.device)

    self.source = args.source
    self.target = args.target
    self.num_k = args.num_k
    self.checkpoint_dir = args.checkpoint_dir
    self.save_epoch = args.save_epoch
    self.use_abs_diff = args.use_abs_diff
    self.mi_k = 1
    self.delta = 0.01
    self.mi_coeff = 0.0001
    self.interval = 10  # write to tensorboard every `interval` steps
    self.batch_size = args.batch_size
    self.which_opt = 'adam'
    self.lr = args.lr
    self.scale = 32
    self.global_step = 0

    print('Loading datasets')
    self.dataset_train, self.dataset_test = dataset_read(
        args.data_dir, self.source, self.target, self.batch_size, self.scale)
    print('Done!')
    self.total_batches = {
        'train': self.get_dataset_size('train'),
        'test': self.get_dataset_size('test')
    }

    self.G = Generator(source=self.source, target=self.target)
    self.FD = Feature_Discriminator()
    self.R = Reconstructor()
    self.MI = Mine()
    self.C = nn.ModuleDict({
        'ds': Classifier(source=self.source, target=self.target),
        'di': Classifier(source=self.source, target=self.target),
        'ci': Classifier(source=self.source, target=self.target)
    })
    self.D = nn.ModuleDict({
        'ds': Disentangler(),
        'di': Disentangler(),
        'ci': Disentangler()
    })
    # All modules in the same dict
    self.components = nn.ModuleDict({
        'G': self.G, 'FD': self.FD, 'R': self.R, 'MI': self.MI
    })

    self.xent_loss = nn.CrossEntropyLoss().to(self.device)
    self.adv_loss = nn.BCEWithLogitsLoss().to(self.device)
    self.ring_loss = RingLoss(type='auto', loss_weight=1.0).to(self.device)
    self.set_optimizer(lr=self.lr)
    self.to_device()
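# --- Usage sketch (illustrative, not part of the original file) ---
# Shows how one-hot domain codes like those above are typically consumed:
# a feature discriminator emits two logits per sample, and BCEWithLogitsLoss
# pushes them toward [1, 0] for source and [0, 1] for target. The tensor
# shapes and the stand-in logits here are assumptions for the sketch.
import torch
import torch.nn as nn

batch_size = 8
src_code = torch.tensor([[1., 0.]]).repeat(batch_size, 1)  # source = [1, 0]
trg_code = torch.tensor([[0., 1.]]).repeat(batch_size, 1)  # target = [0, 1]
logits_src = torch.randn(batch_size, 2)  # stand-in for FD(G(img_s))
logits_trg = torch.randn(batch_size, 2)  # stand-in for FD(G(img_t))
adv_loss = nn.BCEWithLogitsLoss()
loss_d = adv_loss(logits_src, src_code) + adv_loss(logits_trg, trg_code)
print(loss_d.item())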
class Solver(object):
    def __init__(self, args, batch_size=64, source='mnist', target='usps',
                 learning_rate=0.0002, interval=100, optimizer='adam',
                 num_k=4, all_use=False, checkpoint_dir=None, save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2
        self.scale = (self.source == 'svhn')

        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(
            source, target, self.batch_size, scale=self.scale,
            all_use=self.all_use)
        print('load finished!')

        self.G = Generator(source=source, target=target)
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        if args.eval_only:
            # The original eval branch repeated a broken `self.G.torch.load`
            # call three times; restoring G, C1 and C2 from the files written
            # by test() is the assumed intent.
            self.G = torch.load('%s/%s_to_%s_model_epoch%s_G.pt' % (
                self.checkpoint_dir, self.source, self.target, args.resume_epoch))
            self.C1 = torch.load('%s/%s_to_%s_model_epoch%s_C1.pt' % (
                self.checkpoint_dir, self.source, self.target, args.resume_epoch))
            self.C2 = torch.load('%s/%s_to_%s_model_epoch%s_C2.pt' % (
                self.checkpoint_dir, self.source, self.target, args.resume_epoch))
        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.interval = interval
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(), lr=lr,
                                   weight_decay=0.0005, momentum=momentum)
            self.opt_c1 = optim.SGD(self.C1.parameters(), lr=lr,
                                    weight_decay=0.0005, momentum=momentum)
            self.opt_c2 = optim.SGD(self.C2.parameters(), lr=lr,
                                    weight_decay=0.0005, momentum=momentum)
        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(), lr=lr,
                                    weight_decay=0.0005)
            self.opt_c1 = optim.Adam(self.C1.parameters(), lr=lr,
                                     weight_decay=0.0005)
            self.opt_c2 = optim.Adam(self.C2.parameters(), lr=lr,
                                     weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c1.zero_grad()
        self.opt_c2.zero_grad()

    def ent(self, output):
        return -torch.mean(output * torch.log(output + 1e-6))

    def discrepancy(self, out1, out2):
        # L1 distance between the two classifiers' softmax outputs.
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        # initialize an L1 loss for distribution alignment
        criterionConsistency = nn.L1Loss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)
        Tensor = torch.cuda.FloatTensor
        for batch_idx, data in enumerate(self.datasets):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            # for usps and mnist (source)
            z = Variable(Tensor(np.random.normal(0, 1, (2048, 48))))
            # for svhn (source)
            # z = Variable(Tensor(np.random.normal(0, 1, (8192, 128))))
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()

            # Step A: train G, C1, C2 on source labels, plus KL alignment of
            # the source features to the Gaussian prior z.
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            # for usps and mnist (source)
            feat_s_kl = feat_s.view(-1, 48)
            # for svhn (source)
            # feat_s_kl = feat_s.view(-1, 128)
            loss_kld = F.kl_div(F.log_softmax(feat_s_kl, dim=1),
                                F.softmax(z, dim=1))
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2 + self.lambda_1 * loss_kld
            loss_s.backward()
            self.opt_g.step()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            # Step B: fix G, train C1/C2 to stay accurate on source while
            # maximizing their disagreement on target.
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            feat_t = self.G(img_t)
            output_t1 = self.C1(feat_t)
            output_t2 = self.C2(feat_t)
            # for usps and mnist (source)
            feat_s_kl = feat_s.view(-1, 48)
            # for svhn (source)
            # feat_s_kl = feat_s.view(-1, 128)
            loss_kld = F.kl_div(F.log_softmax(feat_s_kl, dim=1),
                                F.softmax(z, dim=1))
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2 + self.lambda_1 * loss_kld
            loss_dis = self.discrepancy(output_t1, output_t2)
            loss = loss_s - loss_dis
            loss.backward()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            # Step C: fix C1/C2, train G num_k times to minimize the
            # discrepancy plus the distribution alignment (DAL) term.
            for i in range(self.num_k):
                feat_t = self.G(img_t)
                output_t1 = self.C1(feat_t)
                output_t2 = self.C2(feat_t)
                # get x_rt
                feat_t_recon = self.G(img_t, is_deconv=True)
                feat_z_recon = self.G.decode(z)
                # distribution alignment loss
                loss_dal = criterionConsistency(feat_t_recon, feat_z_recon)
                # updated loss function
                loss_dis = self.discrepancy(output_t1, output_t2) + \
                    self.lambda_2 * loss_dal
                loss_dis.backward()
                self.opt_g.step()
                self.reset_grad()

            if batch_idx > 500:
                return batch_idx
            if batch_idx % self.interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\t'
                      ' Loss2: {:.6f}\t Discrepancy: {:.6f}'.format(
                          epoch, batch_idx, 100, 100. * batch_idx / 70000,
                          loss_s1.item(), loss_s2.item(), loss_dis.item()))
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s %s %s\n' % (loss_dis.item(),
                                                 loss_s1.item(),
                                                 loss_s2.item()))
                    record.close()
        torch.save(self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
                   (self.checkpoint_dir, self.source, self.target, epoch))
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.C1.eval()
        self.C2.eval()
        test_loss = 0
        correct1 = 0
        correct2 = 0
        correct3 = 0
        size = 0
        with torch.no_grad():  # replaces the deprecated volatile=True flag
            for batch_idx, data in enumerate(self.dataset_test):
                img = data['T']
                label = data['T_label']
                img, label = img.cuda(), label.long().cuda()
                feat = self.G(img)
                output1 = self.C1(feat)
                output2 = self.C2(feat)
                test_loss += F.nll_loss(output1, label).item()
                output_ensemble = output1 + output2
                pred1 = output1.data.max(1)[1]
                pred2 = output2.data.max(1)[1]
                pred_ensemble = output_ensemble.data.max(1)[1]
                k = label.data.size()[0]
                correct1 += pred1.eq(label.data).cpu().sum()
                correct2 += pred2.eq(label.data).cpu().sum()
                correct3 += pred_ensemble.eq(label.data).cpu().sum()
                size += k
        test_loss = test_loss / size
        print('\nTest set: Average loss: {:.4f}, '
              'Accuracy C1: {}/{} ({:.0f}%) '
              'Accuracy C2: {}/{} ({:.0f}%) '
              'Accuracy Ensemble: {}/{} ({:.0f}%)\n'.format(
                  test_loss, correct1, size, 100. * correct1 / size,
                  correct2, size, 100. * correct2 / size,
                  correct3, size, 100. * correct3 / size))
        if save_model and epoch % self.save_epoch == 0:
            torch.save(self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
                       (self.checkpoint_dir, self.source, self.target, epoch))
            torch.save(self.C1, '%s/%s_to_%s_model_epoch%s_C1.pt' %
                       (self.checkpoint_dir, self.source, self.target, epoch))
            torch.save(self.C2, '%s/%s_to_%s_model_epoch%s_C2.pt' %
                       (self.checkpoint_dir, self.source, self.target, epoch))
        if record_file:
            record = open(record_file, 'a')
            print('recording %s' % record_file)
            record.write('%s %s %s\n' % (float(correct1) / size,
                                         float(correct2) / size,
                                         float(correct3) / size))
            record.close()
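# --- Minimal sketch (illustrative only) of the discrepancy term used above ---
# Mean absolute difference between two classifiers' softmax outputs: identical
# logits give zero discrepancy, disagreeing logits give a positive value.
import torch
import torch.nn.functional as F

def discrepancy(out1, out2):
    return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

logits_a = torch.randn(4, 10)
logits_b = torch.randn(4, 10)
print(discrepancy(logits_a, logits_a).item())  # 0.0
print(discrepancy(logits_a, logits_b).item())  # > 0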
def __init__(self, args, batch_size=64, source='svhn', target='mnist',
             learning_rate=0.0002, interval=1, optimizer='adam', num_k=4,
             all_use=False, checkpoint_dir=None, save_epoch=10):
    timestring = strftime("%Y-%m-%d_%H-%M-%S", gmtime()) + "_%s" % args.exp_name
    self.logdir = os.path.join('./logs', timestring)
    self.logger = SummaryWriter(log_dir=self.logdir)
    self.device = torch.device("cuda" if args.use_cuda else "cpu")

    # One-hot domain codes: source = [1, 0], target = [0, 1].
    self.src_domain_code = np.repeat(np.array([[1, 0]]), batch_size, axis=0)
    self.trg_domain_code = np.repeat(np.array([[0, 1]]), batch_size, axis=0)
    self.src_domain_code = torch.FloatTensor(self.src_domain_code).to(self.device)
    self.trg_domain_code = torch.FloatTensor(self.trg_domain_code).to(self.device)

    self.source = source
    self.target = target
    self.num_k = num_k
    self.mi_k = 1
    self.checkpoint_dir = checkpoint_dir
    self.save_epoch = save_epoch
    self.use_abs_diff = args.use_abs_diff
    self.all_use = all_use
    self.delta = 0.01
    self.mi_coeff = 0.0001
    self.interval = interval
    self.batch_size = batch_size
    self.lr = learning_rate
    self.scale = False

    print('dataset loading')
    self.datasets, self.dataset_test = dataset_read(
        source, target, self.batch_size, scale=self.scale,
        all_use=self.all_use)
    print('load finished!')

    self.G = Generator(source=source, target=target)
    self.FD = Feature_Discriminator()
    self.R = Reconstructor()
    self.MI = Mine()
    self.C = nn.ModuleDict({
        'ds': Classifier(source=source, target=target),
        'di': Classifier(source=source, target=target),
        'ci': Classifier(source=source, target=target)
    })
    self.D = nn.ModuleDict({
        'ds': Disentangler(),
        'di': Disentangler(),
        'ci': Disentangler()
    })
    # All modules in the same dict
    self.modules = nn.ModuleDict({
        'G': self.G, 'FD': self.FD, 'R': self.R, 'MI': self.MI
    })

    if args.eval_only:
        # The original repeated a broken `self.G.torch.load` call three
        # times with identical arguments; a single torch.load of the
        # generator checkpoint is the assumed intent.
        self.G = torch.load('%s/%s_to_%s_model_epoch%s_G.pt' % (
            self.checkpoint_dir, self.source, self.target, args.resume_epoch))

    self.xent_loss = nn.CrossEntropyLoss().cuda()
    self.adv_loss = nn.BCEWithLogitsLoss().cuda()
    self.set_optimizer(which_opt=optimizer, lr=learning_rate)
    self.to_device()
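# --- Illustrative sketch (not from the original file) ---
# Why group components in an nn.ModuleDict, as the solver above does:
# registered submodules are all reached by a single .to(device), .train(),
# .eval(), or .parameters() call. The two Linear layers are stand-ins.
import torch
import torch.nn as nn

components = nn.ModuleDict({
    'G': nn.Linear(8, 4),   # stand-in Generator
    'FD': nn.Linear(4, 2),  # stand-in Feature_Discriminator
})
components.to(torch.device('cpu'))  # moves every registered module
components.eval()                   # switches every module's mode at once
print(sum(p.numel() for p in components.parameters()))  # 46 parameters total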
class SolverDANN(object):
    def __init__(self, args, batch_size=64, source='source', target='target',
                 learning_rate=0.0002, interval=100, optimizer='adam',
                 num_k=4, all_use=False, checkpoint_dir=None, save_epoch=10,
                 leave_one_num=-1, model_name=''):
        self.args = args
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.leave_one_num = leave_one_num

        print('dataset loading')
        download()
        self.data_train, self.data_val, self.data_test = dataset_read(
            source, target, self.batch_size, is_resize=args.is_resize,
            leave_one_num=self.leave_one_num, dataset=args.dataset,
            sensor_num=args.sensor_num)
        print('load finished!')

        self.G = Generator(source=source, target=target,
                           is_resize=args.is_resize, dataset=args.dataset,
                           sensor_num=args.sensor_num)
        self.LC = Classifier(source=source, target=target,
                             is_resize=args.is_resize, dataset=args.dataset)
        self.DC = DomainClassifier(source=source, target=target,
                                   is_resize=args.is_resize,
                                   dataset=args.dataset)
        if args.eval_only:
            self.data_val = self.data_test
            self.G = torch.load(r'checkpoint_DANN/best_model_G' + model_name + '.pt')
            self.LC = torch.load(r'checkpoint_DANN/best_model_C1' + model_name + '.pt')
            self.DC = torch.load(r'checkpoint_DANN/best_model_C2' + model_name + '.pt')
        self.G.cuda()
        self.LC.cuda()
        self.DC.cuda()
        self.interval = interval
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(), lr=lr,
                                   weight_decay=0.0005, momentum=momentum)
            self.opt_lc = optim.SGD(self.LC.parameters(), lr=lr,
                                    weight_decay=0.0005, momentum=momentum)
            self.opt_dc = optim.SGD(self.DC.parameters(), lr=lr,
                                    weight_decay=0.0005, momentum=momentum)
        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(), lr=lr,
                                    weight_decay=0.0005)
            self.opt_lc = optim.Adam(self.LC.parameters(), lr=lr,
                                     weight_decay=0.0005)
            self.opt_dc = optim.Adam(self.DC.parameters(), lr=lr,
                                     weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_lc.zero_grad()
        self.opt_dc.zero_grad()

    def ent(self, output):
        return -torch.mean(output * torch.log(output + 1e-6))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.LC.train()
        self.DC.train()
        torch.cuda.manual_seed(1)
        for batch_idx, data in enumerate(self.data_train):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            domain_label_s = torch.zeros(img_s.shape[0])
            domain_label_t = torch.ones(img_t.shape[0])
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            domain_label_s = Variable(domain_label_s.long().cuda())
            domain_label_t = Variable(domain_label_t.long().cuda())
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()

            # Step 1: train G and the label classifier on source labels.
            feat_s = self.G(img_s)
            output_label_s = self.LC(feat_s)
            loss_label_s = criterion(output_label_s, label_s)
            loss_label_s.backward()
            self.opt_g.step()
            self.opt_lc.step()
            self.reset_grad()

            # Step 2: the domain classifier learns to classify the domain of
            # the data accurately.
            feat_s = self.G(img_s)
            output_domain_s = self.DC(feat_s)
            feat_t = self.G(img_t)
            output_domain_t = self.DC(feat_t)
            loss_domain_s = criterion(output_domain_s, domain_label_s)
            loss_domain_t = criterion(output_domain_t, domain_label_t)
            loss_domain = loss_domain_s + loss_domain_t
            loss_domain.backward()
            self.opt_dc.step()
            self.reset_grad()

            # Step 3: the feature generator confuses the domain classifier by
            # minimizing the negated domain loss (only opt_g steps here).
            feat_s = self.G(img_s)
            output_domain_s = self.DC(feat_s)
            feat_t = self.G(img_t)
            output_domain_t = self.DC(feat_t)
            loss_domain_s = criterion(output_domain_s, domain_label_s)
            loss_domain_t = criterion(output_domain_t, domain_label_t)
            loss_domain = -loss_domain_s - loss_domain_t
            loss_domain.backward()
            self.opt_g.step()
            self.reset_grad()

            if batch_idx > 500:
                return batch_idx
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.LC.eval()
        self.DC.eval()
        correct = 0.0
        size = 0.0
        with torch.no_grad():  # replaces the deprecated volatile=True flag
            for batch_idx, data in enumerate(self.data_val):
                img = data['T']
                label = data['T_label']
                img, label = img.cuda(), label.long().cuda()
                # label = label.squeeze()
                feat = self.G(img)
                output1 = self.LC(feat)
                pred1 = output1.data.max(1)[1]
                k = label.data.size()[0]
                correct += pred1.eq(label.data).cpu().sum()
                size += k
        # if save_model and epoch % self.save_epoch == 0:
        #     torch.save(self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
        #                (self.checkpoint_dir, self.source, self.target, epoch))
        #     torch.save(self.C1, '%s/%s_to_%s_model_epoch%s_C1.pt' %
        #                (self.checkpoint_dir, self.source, self.target, epoch))
        #     torch.save(self.C2, '%s/%s_to_%s_model_epoch%s_C2.pt' %
        #                (self.checkpoint_dir, self.source, self.target, epoch))
        if record_file:
            record = open(record_file, 'a')
            record.write('%s\n' % (float(correct) / size,))
            record.close()
        return float(correct) / size, epoch, size, self.G, self.LC, self.DC

    def test_best(self, G, LC, DC):
        G.eval()
        LC.eval()
        DC.eval()
        test_loss = 0
        correct = 0
        size = 0
        with torch.no_grad():
            for batch_idx, data in enumerate(self.data_test):
                img = data['T']
                label = data['T_label']
                img, label = img.cuda(), label.long().cuda()
                label = label.squeeze()
                feat = G(img)
                output = LC(feat)
                test_loss += F.nll_loss(output, label).item()
                pred1 = output.data.max(1)[1]
                k = label.data.size()[0]
                correct += pred1.eq(label.data).cpu().sum()
                size += k
        test_loss = test_loss / size
        print('Best test target acc:', 100.0 * correct.numpy() / size, '%')
        return correct.numpy() / size

    def calc_correct_ensemble(self, G, LC, DC, x, y):
        x, y = x.cuda(), y.long().cuda()
        y = y.squeeze()
        with torch.no_grad():
            feat = G(x)
            output = LC(feat)
        pred = output.data.max(1)[1]
        correct_num = pred.eq(y.data).cpu().sum()
        if len(y.data.size()) == 0:
            print('Error, the size of y is 0!')
            return 0, 0
        size_data = y.data.size()[0]
        return correct_num, size_data

    def calc_test_acc(self, G, LC, DC, set_name='T'):
        correct_all = 0
        size_all = 0
        for batch_idx, data in enumerate(self.data_test):
            correct_num, size_data = self.calc_correct_ensemble(
                G, LC, DC, data[set_name], data[set_name + '_label'])
            if 0 != size_data:
                correct_all += correct_num
                size_all += size_data
        return correct_all.numpy() / size_all

    def test_ensemble(self, G, LC, DC):
        G.eval()
        LC.eval()
        DC.eval()
        acc_s = self.calc_test_acc(G, LC, DC, set_name='S')
        print('Final test source acc:', 100.0 * acc_s, '%')
        acc_t = self.calc_test_acc(G, LC, DC, set_name='T')
        print('Final test target acc:', 100.0 * acc_t, '%')
        return acc_s, acc_t

    def input_feature(self):
        feature_vec = np.zeros(0)
        label_vec = np.zeros(0)
        domain_vec = np.zeros(0)
        for batch_idx, data in enumerate(self.data_test):
            if data['S'].shape[0] != self.batch_size or \
                    data['T'].shape[0] != self.batch_size:
                continue
            if batch_idx > 6:
                break
            feature_s = data['S'].reshape((self.batch_size, -1))
            label_s = data['S_label'].squeeze()
            domain_s = np.zeros(label_s.shape)
            feature_t = data['T'].reshape((self.batch_size, -1))
            label_t = data['T_label'].squeeze()
            domain_t = np.ones(label_t.shape)
            feature_c = np.concatenate([feature_s, feature_t])
            if 0 == feature_vec.shape[0]:
                feature_vec = np.copy(feature_c)
            else:
                feature_vec = np.r_[feature_vec, feature_c]
            label_c = np.concatenate([label_s, label_t])
            domain_c = np.concatenate([domain_s, domain_t])
            label_vec = np.concatenate([label_vec, label_c])
            domain_vec = np.concatenate([domain_vec, domain_c])
        return feature_vec, label_vec, domain_vec
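# --- Minimal sketch (illustrative only) of the confusion step above ---
# Negating the domain-classification loss and stepping only the generator's
# optimizer performs one adversarial (min-max) update without a gradient
# reversal layer. All module names here are stand-ins, not the solver's own.
import torch
import torch.nn as nn

G = nn.Linear(16, 8)   # stand-in feature generator
DC = nn.Linear(8, 2)   # stand-in domain classifier
opt_g = torch.optim.Adam(G.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x_s, x_t = torch.randn(4, 16), torch.randn(4, 16)
d_s, d_t = torch.zeros(4).long(), torch.ones(4).long()
loss = -criterion(DC(G(x_s)), d_s) - criterion(DC(G(x_t)), d_t)
loss.backward()  # gradients reach both G and DC...
opt_g.step()     # ...but only G's parameters are updated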
class Solver(object):
    def __init__(self, args, batch_size=64, source='source', target='target',
                 learning_rate=0.0002, interval=100, optimizer='adam',
                 num_k=4, all_use=False, checkpoint_dir=None, save_epoch=10,
                 leave_one_num=-1, model_name=''):
        self.args = args
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.leave_one_num = leave_one_num

        print('dataset loading')
        download()
        self.data_train, self.data_val, self.data_test = dataset_read(
            source, target, self.batch_size, is_resize=args.is_resize,
            leave_one_num=self.leave_one_num, dataset=args.dataset,
            sensor_num=args.sensor_num)
        print('load finished!')

        self.G = Generator(source=source, target=target,
                           is_resize=args.is_resize, dataset=args.dataset,
                           sensor_num=args.sensor_num)
        self.C1 = Classifier(source=source, target=target,
                             is_resize=args.is_resize, dataset=args.dataset)
        self.C2 = Classifier(source=source, target=target,
                             is_resize=args.is_resize, dataset=args.dataset)
        if args.eval_only:
            self.data_val = self.data_test
            self.G = torch.load(r'checkpoint/best_model_G' + model_name + '.pt')
            self.C1 = torch.load(r'checkpoint/best_model_C1' + model_name + '.pt')
            self.C2 = torch.load(r'checkpoint/best_model_C2' + model_name + '.pt')
        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.interval = interval
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(), lr=lr,
                                   weight_decay=0.0005, momentum=momentum)
            self.opt_c1 = optim.SGD(self.C1.parameters(), lr=lr,
                                    weight_decay=0.0005, momentum=momentum)
            self.opt_c2 = optim.SGD(self.C2.parameters(), lr=lr,
                                    weight_decay=0.0005, momentum=momentum)
        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(), lr=lr,
                                    weight_decay=0.0005)
            self.opt_c1 = optim.Adam(self.C1.parameters(), lr=lr,
                                     weight_decay=0.0005)
            self.opt_c2 = optim.Adam(self.C2.parameters(), lr=lr,
                                     weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c1.zero_grad()
        self.opt_c2.zero_grad()

    def ent(self, output):
        return -torch.mean(output * torch.log(output + 1e-6))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train_souce_only(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)
        for batch_idx, data in enumerate(self.data_train):
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            label_s = Variable(label_s.long().cuda())
            label_s = label_s.squeeze()
            img_s = Variable(img_s)
            self.reset_grad()
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_s.backward()
            self.opt_g.step()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()
            if batch_idx > 500:
                return batch_idx
            if batch_idx % self.interval == 0:
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s \n' % (loss_s.item()))
                    record.close()
        return batch_idx

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)
        for batch_idx, data in enumerate(self.data_train):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            label_s = label_s.squeeze()
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()

            # Step A: source-supervised update of G, C1 and C2.
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_s.backward()
            self.opt_g.step()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            # Step B: train the classifiers to maximize target discrepancy.
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            feat_t = self.G(img_t)
            output_t1 = self.C1(feat_t)
            output_t2 = self.C2(feat_t)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_dis = self.discrepancy(output_t1, output_t2)
            loss = loss_s - 4 * loss_dis  # weight 1: 92.9; 2: 93.1; 3: 93.5; 5: 93.46%
            loss.backward()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            # Step C: train G num_k times to minimize the discrepancy. The
            # feature recomputation was commented out in the original; it is
            # assumed intended, since each backward pass needs a fresh graph.
            for i in range(self.num_k):
                feat_t = self.G(img_t)
                output_t1 = self.C1(feat_t)
                output_t2 = self.C2(feat_t)
                loss_dis = self.discrepancy(output_t1, output_t2)
                loss_dis.backward()
                self.opt_g.step()
                self.reset_grad()

            if batch_idx > 500:
                return batch_idx
            if batch_idx % self.interval == 0:
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s %s %s\n' % (loss_dis.item(),
                                                 loss_s1.item(),
                                                 loss_s2.item()))
                    record.close()
        return batch_idx

    def train_onestep(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)
        for batch_idx, data in enumerate(self.data_train):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_s.backward(retain_graph=True)  # retain_variables was renamed
            feat_t = self.G(img_t)
            # Gradient reversal: with reverse=True the classifiers flip the
            # sign of the gradient flowing back into G.
            self.C1.set_lambda(1.0)
            self.C2.set_lambda(1.0)
            output_t1 = self.C1(feat_t, reverse=True)
            output_t2 = self.C2(feat_t, reverse=True)
            loss_dis = -self.discrepancy(output_t1, output_t2)
            # loss_dis.backward()
            self.opt_c1.step()
            self.opt_c2.step()
            self.opt_g.step()
            self.reset_grad()
            if batch_idx > 500:
                return batch_idx
            if batch_idx % self.interval == 0:
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s %s %s\n' % (loss_dis.item(),
                                                 loss_s1.item(),
                                                 loss_s2.item()))
                    record.close()
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.C1.eval()
        self.C2.eval()
        test_loss = 0.0
        correct1 = 0.0
        correct2 = 0.0
        correct3 = 0.0
        size = 0.0
        with torch.no_grad():  # replaces the deprecated volatile=True flag
            for batch_idx, data in enumerate(self.data_val):
                img = data['T']
                label = data['T_label']
                img, label = img.cuda(), label.long().cuda()
                # label = label.squeeze()
                feat = self.G(img)
                output1 = self.C1(feat)
                output2 = self.C2(feat)
                test_loss += F.nll_loss(output1, label).item()
                output_ensemble = output1 + output2
                pred1 = output1.data.max(1)[1]
                pred2 = output2.data.max(1)[1]
                pred_ensemble = output_ensemble.data.max(1)[1]
                k = label.data.size()[0]
                correct1 += pred1.eq(label.data).cpu().sum()
                correct2 += pred2.eq(label.data).cpu().sum()
                correct3 += pred_ensemble.eq(label.data).cpu().sum()
                size += k
        test_loss = test_loss / size
        # if save_model and epoch % self.save_epoch == 0:
        #     torch.save(self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
        #                (self.checkpoint_dir, self.source, self.target, epoch))
        #     torch.save(self.C1, '%s/%s_to_%s_model_epoch%s_C1.pt' %
        #                (self.checkpoint_dir, self.source, self.target, epoch))
        #     torch.save(self.C2, '%s/%s_to_%s_model_epoch%s_C2.pt' %
        #                (self.checkpoint_dir, self.source, self.target, epoch))
        if record_file:
            record = open(record_file, 'a')
            record.write('%s %s %s\n' % (float(correct1) / size,
                                         float(correct2) / size,
                                         float(correct3) / size))
            record.close()
        return float(correct3) / size, epoch, size, self.G, self.C1, self.C2

    def test_best(self, G, C1, C2):
        G.eval()
        C1.eval()
        C2.eval()
        test_loss = 0
        correct1 = 0
        correct2 = 0
        correct3 = 0
        size = 0
        with torch.no_grad():
            for batch_idx, data in enumerate(self.data_test):
                img = data['T']
                label = data['T_label']
                img, label = img.cuda(), label.long().cuda()
                label = label.squeeze()
                feat = G(img)
                output1 = C1(feat)
                output2 = C2(feat)
                test_loss += F.nll_loss(output1, label).item()
                output_ensemble = output1 + output2
                pred1 = output1.data.max(1)[1]
                pred2 = output2.data.max(1)[1]
                pred_ensemble = output_ensemble.data.max(1)[1]
                k = label.data.size()[0]
                correct1 += pred1.eq(label.data).cpu().sum()
                correct2 += pred2.eq(label.data).cpu().sum()
                correct3 += pred_ensemble.eq(label.data).cpu().sum()
                size += k
        test_loss = test_loss / size
        print('Best test target acc:', 100.0 * correct3.numpy() / size, '%')
        return correct3.numpy() / size

    def calc_correct_ensemble(self, G, C1, C2, x, y):
        x, y = x.cuda(), y.long().cuda()
        y = y.squeeze()
        with torch.no_grad():
            feat = G(x)
            output1 = C1(feat)
            output2 = C2(feat)
        output_ensemble = output1 + output2
        pred_ensemble = output_ensemble.data.max(1)[1]
        correct_num = pred_ensemble.eq(y.data).cpu().sum()
        if len(y.data.size()) == 0:
            return 0, 0
        size_data = y.data.size()[0]
        return correct_num, size_data

    def calc_test_acc(self, G, C1, C2, set_name='T'):
        correct_all = 0
        size_all = 0
        for batch_idx, data in enumerate(self.data_test):
            correct_num, size_data = self.calc_correct_ensemble(
                G, C1, C2, data[set_name], data[set_name + '_label'])
            if 0 != size_data:
                correct_all += correct_num
                size_all += size_data
        return correct_all.numpy() / size_all

    def test_ensemble(self, G, C1, C2):
        G.eval()
        C1.eval()
        C2.eval()
        acc_s = self.calc_test_acc(G, C1, C2, set_name='S')
        print('Final test source acc:', 100.0 * acc_s, '%')
        acc_t = self.calc_test_acc(G, C1, C2, set_name='T')
        print('Final test target acc:', 100.0 * acc_t, '%')
        return acc_s, acc_t

    def input_feature(self):
        feature_vec = np.zeros(0)
        label_vec = np.zeros(0)
        domain_vec = np.zeros(0)
        for batch_idx, data in enumerate(self.data_test):
            if data['S'].shape[0] != self.batch_size or \
                    data['T'].shape[0] != self.batch_size:
                continue
            if batch_idx > 6:
                break
            feature_s = data['S'].reshape((self.batch_size, -1))
            label_s = data['S_label'].squeeze()
            domain_s = np.zeros(label_s.shape)
            feature_t = data['T'].reshape((self.batch_size, -1))
            label_t = data['T_label'].squeeze()
            domain_t = np.ones(label_t.shape)
            feature_c = np.concatenate([feature_s, feature_t])
            if 0 == feature_vec.shape[0]:
                feature_vec = np.copy(feature_c)
            else:
                feature_vec = np.r_[feature_vec, feature_c]
            label_c = np.concatenate([label_s, label_t])
            domain_c = np.concatenate([domain_s, domain_t])
            label_vec = np.concatenate([label_vec, label_c])
            domain_vec = np.concatenate([domain_vec, domain_c])
        return feature_vec, label_vec, domain_vec

    def tsne_feature(self):
        self.G.eval()
        feature_vec = torch.tensor(()).cuda()
        label_vec = np.zeros(0)
        domain_vec = np.zeros(0)
        for batch_idx, data in enumerate(self.data_test):
            if data['S'].shape[0] != self.batch_size or \
                    data['T'].shape[0] != self.batch_size:
                continue
            if batch_idx > 6:
                break
            img_s = data['S']
            label_s = data['S_label'].squeeze()
            domain_s = np.zeros(label_s.shape)
            img_t = data['T']
            label_t = data['T_label'].squeeze()
            domain_t = np.ones(label_t.shape)
            img_c = np.vstack([img_s, img_t])
            img_c = torch.from_numpy(img_c).cuda()
            with torch.no_grad():  # replaces the deprecated volatile=True flag
                feat_c = self.G(img_c)
            feature_vec = torch.cat((feature_vec, feat_c), 0)
            label_c = np.concatenate([label_s, label_t])
            domain_c = np.concatenate([domain_s, domain_t])
            label_vec = np.concatenate([label_vec, label_c])
            domain_vec = np.concatenate([domain_vec, domain_c])
        return feature_vec.cpu().detach().numpy(), label_vec, domain_vec
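# --- Usage sketch (illustrative only): projecting tsne_feature output ---
# The method above returns flattened features plus label/domain vectors; a
# typical downstream step is a 2-D t-SNE embedding colored by domain.
# scikit-learn and the random stand-in arrays are assumptions of this sketch,
# not dependencies of the solver itself.
import numpy as np
from sklearn.manifold import TSNE

features = np.random.randn(128, 48)  # stand-in for feature_vec
domains = np.repeat([0, 1], 64)      # stand-in for domain_vec
embedding = TSNE(n_components=2, init='pca').fit_transform(features)
print(embedding.shape)               # (128, 2), one point per sample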
class Solver(object):
    def __init__(self, args, batch_size=256, source='mnist', target='usps',
                 learning_rate=0.02, interval=100, optimizer='momentum',
                 num_k=4, all_use=False, checkpoint_dir=None, save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.alpha = args.alpha
        self.beta = args.beta
        self.scale = (self.source == 'svhn')

        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(
            source, target, self.batch_size, scale=self.scale,
            all_use=self.all_use)
        print('load finished!')

        self.G = Generator(source=source, target=target)
        self.C = Classifier(source=source, target=target)
        if args.eval_only:
            # The original eval branch repeated a broken `self.G.torch.load`
            # call; restoring G and C from the files written by test() is the
            # assumed intent.
            self.G = torch.load('%s/%s_to_%s_model_epoch%s_G.pt' % (
                self.checkpoint_dir, self.source, self.target, args.resume_epoch))
            self.C = torch.load('%s/%s_to_%s_model_epoch%s_C.pt' % (
                self.checkpoint_dir, self.source, self.target, args.resume_epoch))
        self.G.cuda()
        self.C.cuda()
        self.interval = interval
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.02, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(), lr=lr,
                                   weight_decay=0.0005, momentum=momentum)
            self.opt_c = optim.SGD(self.C.parameters(), lr=lr,
                                   weight_decay=0.0005, momentum=momentum)
        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(), lr=lr,
                                    weight_decay=0.0005)
            self.opt_c = optim.Adam(self.C.parameters(), lr=lr,
                                    weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c.zero_grad()

    def get_entropy_loss(self, p_softmax):
        # Masked entropy of the softmax predictions (scaled by 0.1).
        mask = p_softmax.ge(0.000001)
        mask_out = torch.masked_select(p_softmax, mask)
        entropy = -(torch.sum(mask_out * torch.log(mask_out)))
        return 0.1 * (entropy / float(p_softmax.size(0)))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        # initialize an L1 loss for DAL
        criterionDAL = nn.L1Loss().cuda()
        self.G.train()
        self.C.train()
        torch.cuda.manual_seed(1)
        Tensor = torch.cuda.FloatTensor
        for batch_idx, data in enumerate(self.datasets):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            # for mnist or usps (source)
            zn = Variable(Tensor(np.random.normal(0, 1, (4096, 48))))
            # for svhn (source)
            # zn = Variable(Tensor(np.random.normal(0, 1, (16384, 128))))
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()

            # Step A: source classification plus KL alignment of the source
            # features to the Gaussian prior zn.
            feat_s = self.G(img_s)
            output_s = self.C(feat_s)
            feat_t = self.G(img_t)
            output_t = self.C(feat_t)
            # for mnist or usps (source)
            feat_s_kl = feat_s.view(-1, 48)
            # for svhn (source)
            # feat_s_kl = feat_s.view(-1, 128)
            loss_kld_s = F.kl_div(F.log_softmax(feat_s_kl, dim=1),
                                  F.softmax(zn, dim=1))
            loss_s = criterion(output_s, label_s)
            loss = loss_s + self.alpha * loss_kld_s
            loss.backward()
            self.opt_g.step()
            self.opt_c.step()
            self.reset_grad()

            # Step B: target entropy minimization plus the DAL term between
            # the reconstructed target image and the decoded prior sample.
            feat_t = self.G(img_t)
            output_t = self.C(feat_t)
            feat_t_recon = self.G(img_t, is_deconv=True)
            feat_zn_recon = self.G.decode(zn)
            # DAL
            loss_dal = criterionDAL(feat_t_recon, feat_zn_recon)
            # entropy loss
            t_prob = F.softmax(output_t, dim=1)
            t_entropy_loss = self.get_entropy_loss(t_prob)
            loss = t_entropy_loss + self.beta * loss_dal
            loss.backward()
            self.opt_g.step()
            self.opt_c.step()
            self.reset_grad()

            if batch_idx > 500:
                return batch_idx
            if batch_idx % self.interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t'
                      ' Entropy: {:.6f}'.format(
                          epoch, batch_idx, 100, 100. * batch_idx / 70000,
                          loss_s.item(), t_entropy_loss.item()))
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s %s\n' % (t_entropy_loss.item(),
                                              loss_s.item()))
                    record.close()
        torch.save(self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
                   (self.checkpoint_dir, self.source, self.target, epoch))
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.C.eval()
        test_loss = 0
        correct = 0
        size = 0
        with torch.no_grad():  # replaces the deprecated volatile=True flag
            for batch_idx, data in enumerate(self.dataset_test):
                img = data['T']
                label = data['T_label']
                img, label = img.cuda(), label.long().cuda()
                feat = self.G(img)
                output = self.C(feat)
                test_loss += F.nll_loss(output, label).item()
                pred = output.data.max(1)[1]
                k = label.data.size()[0]
                correct += pred.eq(label.data).cpu().sum()
                size += k
        test_loss = test_loss / size
        print('\nTest set: Average loss: {:.4f}, '
              'Accuracy C: {}/{} ({:.0f}%)\n'.format(
                  test_loss, correct, size, 100. * correct / size))
        if save_model and epoch % self.save_epoch == 0:
            torch.save(self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
                       (self.checkpoint_dir, self.source, self.target, epoch))
            torch.save(self.C, '%s/%s_to_%s_model_epoch%s_C.pt' %
                       (self.checkpoint_dir, self.source, self.target, epoch))
        if record_file:
            record = open(record_file, 'a')
            print('recording %s' % record_file)
            record.write('%s\n' % (float(correct) / size))
            record.close()
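# --- Minimal sketch (illustrative only) of get_entropy_loss above ---
# Masked entropy of softmax predictions: minimizing it pushes the target
# predictions toward confident (low-entropy) outputs. A uniform distribution
# gives the maximal value, a peaked one gives near zero.
import torch
import torch.nn.functional as F

def get_entropy_loss(p_softmax):
    mask = p_softmax.ge(0.000001)  # guard against log(0)
    mask_out = torch.masked_select(p_softmax, mask)
    entropy = -torch.sum(mask_out * torch.log(mask_out))
    return 0.1 * (entropy / float(p_softmax.size(0)))

uniform = F.softmax(torch.zeros(4, 10), dim=1)
peaked = F.softmax(torch.eye(10)[:4] * 50, dim=1)
print(get_entropy_loss(uniform).item())  # high entropy
print(get_entropy_loss(peaked).item())   # near zero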