def main():
    # Create the output directory
    path_output = './checkpoints/'
    if not os.path.exists(path_output):
        os.makedirs(path_output)

    # Hyperparameters (tune as needed)
    epochs = 30
    batch_size = 8
    alpha = 1  # gradient-reversal coefficient; overwritten by the DANN schedule in the training loop
    gamma = 1  # weight of the domain (discriminator) loss

    # Source and target domain names
    root = 'data/'
    source1 = 'real'
    source2 = 'sketch'
    source3 = 'infograph'
    target = 'quickdraw'

    # Datasets and dataloaders
    dataset_s1 = dataset.DA(dir=root, name=source1, img_size=(224, 224), train=True)
    dataset_s2 = dataset.DA(dir=root, name=source2, img_size=(224, 224), train=True)
    dataset_s3 = dataset.DA(dir=root, name=source3, img_size=(224, 224), train=True)
    dataset_t = dataset.DA(dir=root, name=target, img_size=(224, 224), train=True)
    dataset_val = dataset.DA(dir=root, name=target, img_size=(224, 224), train=True, real_val=False)

    dataloader_s1 = DataLoader(dataset_s1, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_s2 = DataLoader(dataset_s2, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_s3 = DataLoader(dataset_s3, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_t = DataLoader(dataset_t, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=2)

    # Length of the smallest domain (zip() stops at the shortest iterator)
    len_data = min(len(dataset_s1), len(dataset_s2), len(dataset_s3), len(dataset_t))
    len_dataloader = min(len(dataloader_s1), len(dataloader_s2),
                         len(dataloader_s3), len(dataloader_t))

    # Define networks
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()
    classifier_1.apply(weight_init)
    classifier_2.apply(weight_init)
    classifier_3.apply(weight_init)
    discriminator = models.discriminator()
    discriminator.apply(weight_init)

    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda()
        classifier_1 = classifier_1.cuda()
        classifier_2 = classifier_2.cuda()
        classifier_3 = classifier_3.cuda()
        discriminator = discriminator.cuda()

    # Losses
    cl_loss = nn.CrossEntropyLoss()
    disc_loss = nn.NLLLoss()

    # Optimizers (learning rates are decayed manually below)
    optimizer_features = SGD(feature_extractor.parameters(), lr=0.0001,
                             momentum=0.9, weight_decay=5e-4)
    optimizer_classifier = SGD([{'params': classifier_1.parameters()},
                                {'params': classifier_2.parameters()},
                                {'params': classifier_3.parameters()}],
                               lr=0.002, momentum=0.9, weight_decay=5e-4)
    optimizer_discriminator = SGD([{'params': discriminator.parameters()}],
                                  lr=0.002, momentum=0.9, weight_decay=5e-4)

    # Histories
    training_loss = []
    train_loss = []
    train_class_loss = []
    train_domain_loss = []
    val_class_loss = []
    val_domain_loss = []
    acc_on_target = []
    best_acc = 0.0

    for epoch in range(epochs):
        epochTic = timeit.default_timer()
        tot_loss = 0.0
        tot_c_loss = 0.0
        tot_d_loss = 0.0
        tot_val_c_loss = 0.0
        tot_val_d_loss = 0.0
        w1_mean = 0.0
        w2_mean = 0.0
        w3_mean = 0.0

        feature_extractor.train()
        classifier_1.train(), classifier_2.train(), classifier_3.train()
        discriminator.train()

        # Manual learning-rate decay for the classifiers and the discriminator
        if epoch + 1 == 5:
            optimizer_classifier = SGD([{'params': classifier_1.parameters()},
                                        {'params': classifier_2.parameters()},
                                        {'params': classifier_3.parameters()}],
                                       lr=0.001, momentum=0.9, weight_decay=5e-4)
            optimizer_discriminator = SGD([{'params': discriminator.parameters()}],
                                          lr=0.001, momentum=0.9, weight_decay=5e-4)
        if epoch + 1 == 10:
            optimizer_classifier = SGD([{'params': classifier_1.parameters()},
                                        {'params': classifier_2.parameters()},
                                        {'params': classifier_3.parameters()}],
                                       lr=0.0001, momentum=0.9, weight_decay=5e-4)
            optimizer_discriminator = SGD([{'params': discriminator.parameters()}],
                                          lr=0.0001, momentum=0.9, weight_decay=5e-4)

        print('*************************************************')
        for i, (data_1, data_2, data_3, data_t) in enumerate(
                zip(dataloader_s1, dataloader_s2, dataloader_s3, dataloader_t)):
            # DANN schedule: ramp the gradient-reversal coefficient from 0 towards 1
            p = float(i + epoch * len_data) / epochs / len_data
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            img1, lb1 = data_1
            img2, lb2 = data_2
            img3, lb3 = data_3
            imgt, _ = data_t

            # Trim all batches to the smallest batch size of this step
            cur_batch = min(img1.shape[0], img2.shape[0], img3.shape[0], imgt.shape[0])
            img1, lb1 = Variable(img1[0:cur_batch, :, :, :]).cuda(), Variable(lb1[0:cur_batch]).cuda()
            img2, lb2 = Variable(img2[0:cur_batch, :, :, :]).cuda(), Variable(lb2[0:cur_batch]).cuda()
            img3, lb3 = Variable(img3[0:cur_batch, :, :, :]).cuda(), Variable(lb3[0:cur_batch]).cuda()
            imgt = Variable(imgt[0:cur_batch, :, :, :]).cuda()

            # Forward
            optimizer_features.zero_grad()
            optimizer_classifier.zero_grad()
            optimizer_discriminator.zero_grad()

            # Extract features
            ft1 = feature_extractor(img1)
            ft2 = feature_extractor(img2)
            ft3 = feature_extractor(img3)
            ft_t = feature_extractor(imgt)

            # Domain predictions (the three source batches are concatenated)
            ds_s1 = discriminator(torch.cat((ft1, ft2, ft3)), alpha)
            ds_t = discriminator(ft_t, alpha)

            # Class predictions
            cl1 = classifier_1(ft1)
            cl2 = classifier_2(ft2)
            cl3 = classifier_3(ft3)

            # Discriminator loss: source = 0, target = 1
            ds_label = torch.zeros(cur_batch * 3).long()
            dt_label = torch.ones(cur_batch).long()
            d_s = disc_loss(ds_s1, ds_label.cuda())
            d_t = disc_loss(ds_t, dt_label.cuda())

            # Cross-entropy losses
            l1 = cl_loss(cl1, lb1)
            l2 = cl_loss(cl2, lb2)
            l3 = cl_loss(cl3, lb3)

            # Classifier weights: inversely proportional to each classifier's loss
            total_class_loss = 1 / l1 + 1 / l2 + 1 / l3
            w1 = (1 / l1) / total_class_loss
            w2 = (1 / l2) / total_class_loss
            w3 = (1 / l3) / total_class_loss
            w1_mean += w1
            w2_mean += w2
            w3_mean += w3

            # Total loss
            Class_loss = l1 + l2 + l3
            Domain_loss = gamma * (d_s + d_t)
            loss = Class_loss + Domain_loss
            loss.backward()
            optimizer_features.step()
            optimizer_classifier.step()
            optimizer_discriminator.step()

            tot_loss += loss.item() * cur_batch
            tot_c_loss += Class_loss.item()
            tot_d_loss += Domain_loss.item()

            # Progress indicator
            print('\rTraining... Progress: %.1f %%' % (100 * (i + 1) / len_dataloader), end='')

            # Log class loss and domain loss every 50 iterations
            if i % 50 == 0:
                train_class_loss.append(tot_c_loss / (i + 1))
                train_domain_loss.append(tot_d_loss / (i + 1))
                train_loss.append(tot_loss / (i + 1) / cur_batch)

        tot_t_loss = tot_loss / len_data
        training_loss.append(tot_t_loss)
        w1_mean /= len_dataloader
        w2_mean /= len_dataloader
        w3_mean /= len_dataloader
        # print(w1_mean, w2_mean, w3_mean)
        print('\rEpoch [%d/%d], Training loss: %.4f' % (epoch + 1, epochs, tot_t_loss), end='\n')

        ################################################################################
        # Compute the accuracy on the target validation set at the end of each epoch
        feature_extractor.eval()
        classifier_1.eval(), classifier_2.eval(), classifier_3.eval()
        discriminator.eval()
        tot_acc = 0
        with torch.no_grad():
            for i, (imgt, lbt) in enumerate(dataloader_val):
                cur_batch = imgt.shape[0]
                imgt = imgt.cuda()
                lbt = lbt.cuda()

                # Forward the validation images
                ft_t = feature_extractor(imgt)
                pred1 = classifier_1(ft_t)
                pred2 = classifier_2(ft_t)
                pred3 = classifier_3(ft_t)
                val_ds_t = discriminator(ft_t, alpha)

                # Class loss
                val_l1 = cl_loss(pred1, lbt)
                val_l2 = cl_loss(pred2, lbt)
                val_l3 = cl_loss(pred3, lbt)
                val_CE_loss = val_l1 + val_l2 + val_l3

                # Domain loss
                val_dt_label = torch.ones(cur_batch).long()
                val_d_t = disc_loss(val_ds_t, val_dt_label.cuda())

                # Accuracy of the weighted ensemble
                output = pred1 * w1_mean + pred2 * w2_mean + pred3 * w3_mean
                _, pred = torch.max(output, dim=1)
                correct = pred.eq(lbt.data.view_as(pred))
                accuracy = torch.mean(correct.type(torch.FloatTensor))
                tot_acc += accuracy.item() * cur_batch

                # Accumulate validation losses
                tot_val_c_loss += val_CE_loss.item()
                tot_val_d_loss += val_d_t.item()

                # Progress indicator
                print('\rValidation... Progress: %.1f %%' % (100 * (i + 1) / len(dataloader_val)), end='')

                # Log validation losses every 50 iterations
                if i % 50 == 0:
                    val_class_loss.append(tot_val_c_loss / (i + 1))
                    val_domain_loss.append(tot_val_d_loss / (i + 1))

        tot_t_acc = tot_acc / len(dataset_val)
        acc_on_target.append(tot_t_acc)
        print('\rEpoch [%d/%d], Accuracy on target: %.4f' % (epoch + 1, epochs, tot_t_acc), end='\n')

        # Save a checkpoint whenever the target accuracy improves
        if best_acc < tot_t_acc:
            torch.save(
                {
                    'epoch': epoch,
                    'feature_extractor': feature_extractor.state_dict(),
                    '{}_classifier'.format(source1): classifier_1.state_dict(),
                    '{}_classifier'.format(source2): classifier_2.state_dict(),
                    '{}_classifier'.format(source3): classifier_3.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'features_optimizer': optimizer_features.state_dict(),
                    'classifier_optimizer': optimizer_classifier.state_dict(),
                    'loss': training_loss,
                    '{}_weight'.format(source1): w1_mean,
                    '{}_weight'.format(source2): w2_mean,
                    '{}_weight'.format(source3): w3_mean,
                },
                os.path.join(path_output, target + '-{}-deming.pth'.format(epoch)))
            print('Saved best model!')
            best_acc = tot_t_acc

        # Print elapsed time per epoch
        epochToc = timeit.default_timer()
        (t_min, t_sec) = divmod((epochToc - epochTic), 60)
        print('Elapsed time is: %d min: %d sec' % (t_min, t_sec))

    # Save training losses and accuracy on the target
    pkl.dump(train_loss, open('{}total_loss_{}.p'.format(path_output, target), 'wb'))
    pkl.dump(train_class_loss, open('{}class_loss_{}.p'.format(path_output, target), 'wb'))
    pkl.dump(train_domain_loss, open('{}domain_loss_{}.p'.format(path_output, target), 'wb'))
    pkl.dump(acc_on_target, open('{}target_accuracy_{}.p'.format(path_output, target), 'wb'))
    pkl.dump(val_class_loss, open('{}val_class_loss_{}.p'.format(path_output, target), 'wb'))
    pkl.dump(val_domain_loss, open('{}val_domain_loss_{}.p'.format(path_output, target), 'wb'))
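
# --------------------------------------------------------------------------
# The training loop above passes `alpha` into models.discriminator() and pairs
# its output with nn.NLLLoss, which suggests a DANN-style gradient reversal
# layer followed by a log-softmax domain classifier. The real module lives in
# models.py; the snippet below is only a minimal sketch of that assumed design
# (GradReverse/SketchDiscriminator, the feature dimension 2048 and the hidden
# size are illustrative placeholders, not values taken from the repository).
import torch
import torch.nn as nn
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; multiplies the gradient by -alpha."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


class SketchDiscriminator(nn.Module):
    """Hypothetical domain classifier: source (label 0) vs. target (label 1)."""

    def __init__(self, in_dim=2048, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 2),
            nn.LogSoftmax(dim=1),  # log-probabilities, matching nn.NLLLoss above
        )

    def forward(self, features, alpha):
        reversed_features = GradReverse.apply(features, alpha)
        return self.net(reversed_features)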
def main():
    # Create the output directory
    path_output = './checkpoints/'
    if not os.path.exists(path_output):
        os.makedirs(path_output)

    # Hyperparameters (tune as needed)
    epochs = 50
    batch_size = 8
    alpha = 1  # trade-off weight between the classification loss and the momentum loss
    save_interval = 10  # save a checkpoint every 10 epochs

    # Source and target domain names
    root = 'data/'
    source1 = 'sketch'
    source2 = 'quickdraw'
    source3 = 'infograph'
    target = 'real'

    # Datasets and dataloaders
    dataset_s1 = dataset.DA(dir=root, name=source1, img_size=(224, 224), train=True)
    dataset_s2 = dataset.DA(dir=root, name=source2, img_size=(224, 224), train=True)
    dataset_s3 = dataset.DA(dir=root, name=source3, img_size=(224, 224), train=True)
    if target == 'real':
        tmp = os.path.join(root, 'test')
        dataset_t = dataset.DA_test(dir=tmp, img_size=(224, 224))
    else:
        dataset_t = dataset.DA(dir=root, name=target, img_size=(224, 224), train=False)

    dataloader_s1 = DataLoader(dataset_s1, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_s2 = DataLoader(dataset_s2, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_s3 = DataLoader(dataset_s3, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloader_t = DataLoader(dataset_t, batch_size=batch_size, shuffle=True, num_workers=2)

    # Length of the smallest domain (zip() stops at the shortest iterator)
    len_data = min(len(dataset_s1), len(dataset_s2), len(dataset_s3), len(dataset_t))

    # Define networks
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()

    # Weight initialization
    classifier_1.apply(weight_init)
    classifier_2.apply(weight_init)
    classifier_3.apply(weight_init)

    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda()
        classifier_1 = classifier_1.cuda()
        classifier_2 = classifier_2.cuda()
        classifier_3 = classifier_3.cuda()

    # Losses
    mom_loss = momentumLoss()
    cl_loss = nn.CrossEntropyLoss()

    # Optimizers
    optimizer_features = Adam(feature_extractor.parameters(), lr=0.0001)
    optimizer_classifier = Adam([{'params': classifier_1.parameters()},
                                 {'params': classifier_2.parameters()},
                                 {'params': classifier_3.parameters()}], lr=0.002)

    # Histories
    train_loss = []
    acc_on_target = []

    for epoch in range(epochs):
        tot_loss = 0.0
        feature_extractor.train()
        classifier_1.train(), classifier_2.train(), classifier_3.train()

        for i, (data_1, data_2, data_3, data_t) in enumerate(
                zip(dataloader_s1, dataloader_s2, dataloader_s3, dataloader_t)):
            img1, lb1 = data_1
            img2, lb2 = data_2
            img3, lb3 = data_3
            if target == 'real':
                imgt = data_t
            else:
                imgt, _ = data_t

            # Trim all batches to the smallest batch size of this step
            cur_batch = min(img1.shape[0], img2.shape[0], img3.shape[0], imgt.shape[0])
            img1, lb1 = Variable(img1[0:cur_batch, :, :, :]).cuda(), Variable(lb1[0:cur_batch]).cuda()
            img2, lb2 = Variable(img2[0:cur_batch, :, :, :]).cuda(), Variable(lb2[0:cur_batch]).cuda()
            img3, lb3 = Variable(img3[0:cur_batch, :, :, :]).cuda(), Variable(lb3[0:cur_batch]).cuda()
            imgt = Variable(imgt[0:cur_batch, :, :, :]).cuda()

            # Forward
            optimizer_features.zero_grad()
            optimizer_classifier.zero_grad()

            # Extract features
            ft1 = feature_extractor(img1)
            ft2 = feature_extractor(img2)
            ft3 = feature_extractor(img3)
            ft_t = feature_extractor(imgt)

            # Class predictions
            cl1 = classifier_1(ft1)
            cl2 = classifier_2(ft2)
            cl3 = classifier_3(ft3)

            # Momentum (moment-matching) loss between source and target features
            loss_mom = mom_loss(ft1, ft2, ft3, ft_t)

            # Cross-entropy losses
            l1 = cl_loss(cl1, lb1)
            l2 = cl_loss(cl2, lb2)
            l3 = cl_loss(cl3, lb3)

            # Total loss
            loss = l1 + l2 + l3 + alpha * loss_mom
            # print(loss_mom, (l1 + l2 + l3))
            loss.backward()
            optimizer_features.step()
            optimizer_classifier.step()

            tot_loss += loss.item() * cur_batch

        tot_t_loss = tot_loss / len_data
        train_loss.append(tot_t_loss)
        print('*************************************************')
        print('Epoch [%d/%d], Training loss: %.4f' % (epoch + 1, epochs, tot_t_loss))

        ################################################################################
        # Compute the accuracy on the target at the end of each epoch
        # (only possible when the target is not the unlabeled 'real' test set)
        if target != 'real':
            feature_extractor.eval()
            classifier_1.eval(), classifier_2.eval(), classifier_3.eval()
            tot_acc = 0
            with torch.no_grad():
                for i, (imgt, lbt) in enumerate(dataloader_t):
                    cur_batch = imgt.shape[0]
                    imgt = imgt.cuda()
                    lbt = lbt.cuda()

                    # Forward the test images
                    ft_t = feature_extractor(imgt)
                    pred1 = classifier_1(ft_t)
                    pred2 = classifier_2(ft_t)
                    pred3 = classifier_3(ft_t)

                    # Accuracy of the averaged ensemble
                    output = torch.mean(torch.stack((pred1, pred2, pred3)), 0)
                    _, pred = torch.max(output, dim=1)
                    correct = pred.eq(lbt.data.view_as(pred))
                    accuracy = torch.mean(correct.type(torch.FloatTensor))
                    tot_acc += accuracy.item() * cur_batch

            tot_t_acc = tot_acc / len(dataset_t)
            acc_on_target.append(tot_t_acc)
            print('Epoch [%d/%d], Accuracy on target: %.4f' % (epoch + 1, epochs, tot_t_acc))

        # Save a checkpoint every save_interval epochs (and at the last epoch)
        if epoch % save_interval == 0 or epoch == epochs - 1:
            torch.save(
                {
                    'epoch': epoch,
                    'feature_extractor': feature_extractor.state_dict(),
                    '{}_classifier'.format(source1): classifier_1.state_dict(),
                    '{}_classifier'.format(source2): classifier_2.state_dict(),
                    '{}_classifier'.format(source3): classifier_3.state_dict(),
                    'features_optimizer': optimizer_features.state_dict(),
                    'classifier_optimizer': optimizer_classifier.state_dict(),
                    'loss': tot_loss,
                },
                os.path.join(path_output, target + '-{}.pth'.format(epoch)))

    # Save training loss and accuracy on the target (if the target is not 'real')
    pkl.dump(train_loss, open('{}train_loss.p'.format(path_output), 'wb'))
    if target != 'real':
        pkl.dump(acc_on_target, open('{}target_accuracy.p'.format(path_output), 'wb'))
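
# --------------------------------------------------------------------------
# momentumLoss() is defined elsewhere in the repository. Judging from how it is
# called, mom_loss(ft1, ft2, ft3, ft_t), it aligns the feature statistics of the
# three source batches with the target batch. The class below is only a
# hypothetical sketch of such a moment-matching criterion (first and second
# moments, each source compared against the target); it is not the project's
# actual implementation.
import torch
import torch.nn as nn


class MomentMatchingLossSketch(nn.Module):
    """Illustrative stand-in: match per-feature mean and variance to the target."""

    def forward(self, ft1, ft2, ft3, ft_t):
        loss = ft_t.new_zeros(())
        t_mean, t_var = ft_t.mean(dim=0), ft_t.var(dim=0)
        for ft in (ft1, ft2, ft3):
            loss = loss + (ft.mean(dim=0) - t_mean).abs().mean()
            loss = loss + (ft.var(dim=0) - t_var).abs().mean()
        return loss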
def train(opt):
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(path_output)

    # Pick the three source domains and the target domain
    source1, source2, source3, target = taskSelect(opt.target)

    # Datasets and dataloaders
    dataset_s1 = dataset.DA(dir=root, name=source1, img_size=(224, 224), train=True)
    dataset_s2 = dataset.DA(dir=root, name=source2, img_size=(224, 224), train=True)
    dataset_s3 = dataset.DA(dir=root, name=source3, img_size=(224, 224), train=True)
    dataset_t = dataset.DA(dir=root, name=target, img_size=(224, 224), train=True)
    dataset_tt = dataset.DA(dir=root, name=target, img_size=(224, 224), train=False, real_val=False)

    dataloader_s1 = DataLoader(dataset_s1, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_s2 = DataLoader(dataset_s2, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_s3 = DataLoader(dataset_s3, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_t = DataLoader(dataset_t, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_tt = DataLoader(dataset_tt, batch_size=opt.bs, shuffle=False, num_workers=2)

    # Alternative setup that evaluates on the unlabeled test split when the target is 'real':
    # dataset_s1 = dataset.DA(dir=root, name=source1, img_size=(224, 224), train=True)
    # dataset_s2 = dataset.DA(dir=root, name=source2, img_size=(224, 224), train=True)
    # dataset_s3 = dataset.DA(dir=root, name=source3, img_size=(224, 224), train=True)
    # dataset_t = dataset.DA(dir=root, name=target, img_size=(224, 224), train=True)
    # if target == 'real':
    #     tmp = os.path.join(root, 'test')
    #     dataset_tt = dataset.DA_test(dir=tmp, img_size=(224, 224))
    # else:
    #     dataset_tt = dataset.DA(dir=root, name=target, img_size=(224, 224), train=False)
    # dataloader_s1 = DataLoader(dataset_s1, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_s2 = DataLoader(dataset_s2, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_s3 = DataLoader(dataset_s3, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_t = DataLoader(dataset_t, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_tt = DataLoader(dataset_tt, batch_size=opt.bs, shuffle=False, num_workers=2)

    # Length of the smallest domain (zip() stops at the shortest iterator)
    len_data = min(len(dataset_s1), len(dataset_s2), len(dataset_s3), len(dataset_t))
    len_bs = min(len(dataloader_s1), len(dataloader_s2), len(dataloader_s3), len(dataloader_t))

    # Define networks: one feature extractor and a pair of classifiers per source domain
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()
    classifier_1_ = models.class_classifier()
    classifier_2_ = models.class_classifier()
    classifier_3_ = models.class_classifier()

    feature_extractor = feature_extractor.to(device)
    classifier_1 = classifier_1.to(device).apply(weight_init)
    classifier_2 = classifier_2.to(device).apply(weight_init)
    classifier_3 = classifier_3.to(device).apply(weight_init)
    classifier_1_ = classifier_1_.to(device).apply(weight_init)
    classifier_2_ = classifier_2_.to(device).apply(weight_init)
    classifier_3_ = classifier_3_.to(device).apply(weight_init)

    # Losses
    mom_loss = momentumLoss()
    cl_loss = nn.CrossEntropyLoss()
    disc_loss = discrepancyLoss()

    # Optimizers (learning rates are decayed manually below)
    optimizer_features = SGD(feature_extractor.parameters(), lr=0.0001,
                             momentum=0.9, weight_decay=5e-4)
    optimizer_classifier = SGD([{'params': classifier_1.parameters()},
                                {'params': classifier_2.parameters()},
                                {'params': classifier_3.parameters()}],
                               lr=0.002, momentum=0.9, weight_decay=5e-4)
    optimizer_classifier_ = SGD([{'params': classifier_1_.parameters()},
                                 {'params': classifier_2_.parameters()},
                                 {'params': classifier_3_.parameters()}],
                                lr=0.002, momentum=0.9, weight_decay=5e-4)

    # Alternative optimizers:
    # optimizer_features = SGD(feature_extractor.parameters(), lr=0.0001)
    # optimizer_classifier = Adam([{'params': classifier_1.parameters()},
    #                              {'params': classifier_2.parameters()},
    #                              {'params': classifier_3.parameters()}], lr=0.002)
    # optimizer_classifier_ = Adam([{'params': classifier_1_.parameters()},
    #                               {'params': classifier_2_.parameters()},
    #                               {'params': classifier_3_.parameters()}], lr=0.002)

    # Optionally resume from a pretrained checkpoint
    if opt.pretrain is not None:
        state = torch.load(opt.pretrain)
        feature_extractor.load_state_dict(state['feature_extractor'])
        classifier_1.load_state_dict(state['{}_classifier'.format(source1)])
        classifier_2.load_state_dict(state['{}_classifier'.format(source2)])
        classifier_3.load_state_dict(state['{}_classifier'.format(source3)])
        classifier_1_.load_state_dict(state['{}_classifier_'.format(source1)])
        classifier_2_.load_state_dict(state['{}_classifier_'.format(source2)])
        classifier_3_.load_state_dict(state['{}_classifier_'.format(source3)])

    # Histories and counters
    train_loss = []
    acc_on_target = []
    tot_loss, tot_clf_loss, tot_mom_loss, tot_s2_loss, tot_s3_loss = 0.0, 0.0, 0.0, 0.0, 0.0
    n_samples, iteration = 0, 0
    # Per-classifier running accuracy; use an ndarray so that `+=` accumulates
    # element-wise (a plain list would be extended instead of summed)
    tot_correct = np.zeros(6)
    saved_time = time.time()

    feature_extractor.train()
    classifier_1.train(), classifier_2.train(), classifier_3.train()
    classifier_1_.train(), classifier_2_.train(), classifier_3_.train()

    for epoch in range(opt.ep):
        # Manual learning-rate decay for both classifier groups
        if epoch + 1 == 5:
            optimizer_classifier = SGD([{'params': classifier_1.parameters()},
                                        {'params': classifier_2.parameters()},
                                        {'params': classifier_3.parameters()}],
                                       lr=0.001, momentum=0.9, weight_decay=5e-4)
            optimizer_classifier_ = SGD([{'params': classifier_1_.parameters()},
                                         {'params': classifier_2_.parameters()},
                                         {'params': classifier_3_.parameters()}],
                                        lr=0.001, momentum=0.9, weight_decay=5e-4)
        if epoch + 1 == 10:
            optimizer_classifier = SGD([{'params': classifier_1.parameters()},
                                        {'params': classifier_2.parameters()},
                                        {'params': classifier_3.parameters()}],
                                       lr=0.0001, momentum=0.9, weight_decay=5e-4)
            optimizer_classifier_ = SGD([{'params': classifier_1_.parameters()},
                                         {'params': classifier_2_.parameters()},
                                         {'params': classifier_3_.parameters()}],
                                        lr=0.0001, momentum=0.9, weight_decay=5e-4)

        for i, (data_1, data_2, data_3, data_t) in enumerate(
                zip(dataloader_s1, dataloader_s2, dataloader_s3, dataloader_t)):
            img1, lb1 = data_1
            img2, lb2 = data_2
            img3, lb3 = data_3
            imgt, _ = data_t

            # Trim all batches to the smallest batch size of this step
            cur_batch = min(img1.shape[0], img2.shape[0], img3.shape[0], imgt.shape[0])
            # print(i, cur_batch)
            img1, lb1 = Variable(img1[0:cur_batch, :, :, :]).to(device), Variable(lb1[0:cur_batch]).to(device)
            img2, lb2 = Variable(img2[0:cur_batch, :, :, :]).to(device), Variable(lb2[0:cur_batch]).to(device)
            img3, lb3 = Variable(img3[0:cur_batch, :, :, :]).to(device), Variable(lb3[0:cur_batch]).to(device)
            imgt = Variable(imgt[0:cur_batch, :, :, :]).to(device)

            ### STEP 1 ### train the feature extractor G and both classifier groups C
            optimizer_features.zero_grad()
            optimizer_classifier.zero_grad()
            optimizer_classifier_.zero_grad()

            # Extract features
            ft1 = feature_extractor(img1)
            ft2 = feature_extractor(img2)
            ft3 = feature_extractor(img3)
            ft_t = feature_extractor(imgt)

            # Class predictions [bs, 345]
            cl1, cl1_ = classifier_1(ft1), classifier_1_(ft1)
            cl2, cl2_ = classifier_2(ft2), classifier_2_(ft2)
            cl3, cl3_ = classifier_3(ft3), classifier_3_(ft3)

            # Momentum (moment-matching) loss between source and target features
            loss_mom = mom_loss(ft1, ft2, ft3, ft_t)

            # Cross-entropy losses
            l1, l1_ = cl_loss(cl1, lb1), cl_loss(cl1_, lb1)
            l2, l2_ = cl_loss(cl2, lb2), cl_loss(cl2_, lb2)
            l3, l3_ = cl_loss(cl3, lb3), cl_loss(cl3_, lb3)

            # Total step-1 loss
            s1loss = l1 + l2 + l3 + l1_ + l2_ + l3_ + opt.alpha * loss_mom
            s1loss.backward()
            optimizer_features.step()
            optimizer_classifier.step()
            optimizer_classifier_.step()

            ### STEP 2 ### fix G, train the classifier pairs to maximise their discrepancy on the target
            optimizer_classifier.zero_grad()
            optimizer_classifier_.zero_grad()

            # Class predictions on each source domain (features detached so G gets no gradient)
            cl1, cl1_ = classifier_1(ft1.detach()), classifier_1_(ft1.detach())
            cl2, cl2_ = classifier_2(ft2.detach()), classifier_2_(ft2.detach())
            cl3, cl3_ = classifier_3(ft3.detach()), classifier_3_(ft3.detach())

            # Discrepancy on the target domain
            clt1, clt1_ = classifier_1(ft_t.detach()), classifier_1_(ft_t.detach())
            clt2, clt2_ = classifier_2(ft_t.detach()), classifier_2_(ft_t.detach())
            clt3, clt3_ = classifier_3(ft_t.detach()), classifier_3_(ft_t.detach())

            # Classification losses
            l1, l1_ = cl_loss(cl1, lb1), cl_loss(cl1_, lb1)
            l2, l2_ = cl_loss(cl2, lb2), cl_loss(cl2_, lb2)
            l3, l3_ = cl_loss(cl3, lb3), cl_loss(cl3_, lb3)
            # print(clt1.shape)
            dl1 = disc_loss(clt1, clt1_)
            dl2 = disc_loss(clt2, clt2_)
            dl3 = disc_loss(clt3, clt3_)
            # print(dl1, dl2, dl3)

            # Backward: the discrepancy enters with a negative sign, so the classifiers maximise it
            s2loss = l1 + l2 + l3 + l1_ + l2_ + l3_ - dl1 - dl2 - dl3
            s2loss.backward()
            optimizer_classifier.step()
            optimizer_classifier_.step()

            ### STEP 3 ### fix the classifier pairs, train G to minimise the discrepancy
            optimizer_features.zero_grad()
            ft_t = feature_extractor(imgt)
            clt1, clt1_ = classifier_1(ft_t), classifier_1_(ft_t)
            clt2, clt2_ = classifier_2(ft_t), classifier_2_(ft_t)
            clt3, clt3_ = classifier_3(ft_t), classifier_3_(ft_t)
            dl1 = disc_loss(clt1, clt1_)
            dl2 = disc_loss(clt2, clt2_)
            dl3 = disc_loss(clt3, clt3_)
            s3loss = dl1 + dl2 + dl3
            s3loss.backward()
            optimizer_features.step()

            # Per-classifier training accuracy
            pred = torch.stack((cl1, cl2, cl3, cl1_, cl2_, cl3_), 0)  # [6, bs, 345]
            _, pred = torch.max(pred, dim=2)                          # [6, bs]
            gt = torch.stack((lb1, lb2, lb3, lb1, lb2, lb3), 0)       # [6, bs]
            correct = pred.eq(gt.data)
            correct = torch.mean(correct.type(torch.FloatTensor), dim=1).cpu().numpy()

            tot_loss += s1loss.item() * cur_batch
            tot_clf_loss += (s1loss.item() - opt.alpha * loss_mom.item()) * cur_batch
            tot_s2_loss += s2loss.item() * cur_batch
            tot_s3_loss += s3loss.item() * cur_batch
            tot_mom_loss += loss_mom.item() * cur_batch
            tot_correct += correct * cur_batch
            n_samples += cur_batch
            # print(cur_batch)

            if iteration % opt.log_interval == 0:
                current_time = time.time()
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tClfLoss: {:.4f}\tMMLoss: {:.4f}\t'
                      'S2Loss: {:.4f}\tS3Loss: {:.4f}\t'
                      'Accu: {:.4f}\\{:.4f}\\{:.4f}\\{:.4f}\\{:.4f}\\{:.4f}\tTime: {:.3f}'.format(
                          epoch, i * opt.bs, len_data, 100. * i / len_bs,
                          tot_clf_loss / n_samples, tot_mom_loss / n_samples,
                          tot_s2_loss / n_samples, tot_s3_loss / n_samples,
                          tot_correct[0] / n_samples, tot_correct[1] / n_samples,
                          tot_correct[2] / n_samples, tot_correct[3] / n_samples,
                          tot_correct[4] / n_samples, tot_correct[5] / n_samples,
                          current_time - saved_time))
                writer.add_scalar('Train/ClfLoss', tot_clf_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/MMLoss', tot_mom_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/s2Loss', tot_s2_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/s3Loss', tot_s3_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu0', tot_correct[0] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu1', tot_correct[1] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu2', tot_correct[2] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu0_', tot_correct[3] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu1_', tot_correct[4] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu2_', tot_correct[5] / n_samples, iteration * opt.bs)

                # Ensemble weights, proportional to each classifier's recent accuracy
                saved_weight = torch.FloatTensor([tot_correct[0], tot_correct[1], tot_correct[2],
                                                  tot_correct[3], tot_correct[4], tot_correct[5]]).to(device)
                if torch.sum(saved_weight) == 0.:
                    saved_weight = torch.FloatTensor(6).to(device).fill_(1) / 6.
                else:
                    saved_weight = saved_weight / torch.sum(saved_weight)

                saved_time = time.time()
                tot_clf_loss, tot_mom_loss, n_samples = 0, 0, 0
                tot_correct = np.zeros(6)
                tot_s2_loss, tot_s3_loss = 0, 0
                train_loss.append(tot_loss)

            # Evaluation and checkpointing
            if iteration % opt.eval_interval == 0 and iteration >= 0 and target != 'real':
                print('weight = ', saved_weight.cpu().numpy())
                evalacc = eval(saved_weight, feature_extractor,
                               classifier_1_, classifier_2_, classifier_3_,
                               classifier_1, classifier_2, classifier_3, dataloader_tt)
                writer.add_scalar('Test/Accu', evalacc, iteration * opt.bs)
                acc_on_target.append(evalacc)
                print('Eval Acc = {:.2f}\n'.format(evalacc * 100))
                torch.save(
                    {
                        'epoch': epoch,
                        'feature_extractor': feature_extractor.state_dict(),
                        '{}_classifier'.format(source1): classifier_1.state_dict(),
                        '{}_classifier'.format(source2): classifier_2.state_dict(),
                        '{}_classifier'.format(source3): classifier_3.state_dict(),
                        '{}_classifier_'.format(source1): classifier_1_.state_dict(),
                        '{}_classifier_'.format(source2): classifier_2_.state_dict(),
                        '{}_classifier_'.format(source3): classifier_3_.state_dict(),
                        'features_optimizer': optimizer_features.state_dict(),
                        'classifier_optimizer': optimizer_classifier.state_dict(),
                        'loss': tot_loss,
                        'saved_weight': saved_weight,
                    },
                    os.path.join(path_output, target + '-{}-{:.2f}.pth'.format(epoch, evalacc * 100)))

            iteration += 1

    pkl.dump(train_loss, open('{}train_loss.p'.format(path_output), 'wb'))
    if target != 'real':
        pkl.dump(acc_on_target, open('{}target_accuracy.p'.format(path_output), 'wb'))
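
# --------------------------------------------------------------------------
# discrepancyLoss() is imported from the project code. The three-step schedule
# above (train G+C, maximise the classifier discrepancy with G fixed, then
# minimise it with the classifiers fixed) follows the MCD recipe, so the
# criterion is presumably the mean L1 distance between the two classifiers'
# softmax outputs. The class below is only a sketch of that assumption, not the
# repository's actual implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiscrepancyLossSketch(nn.Module):
    """Illustrative stand-in: mean absolute difference of class probabilities."""

    def forward(self, logits_a, logits_b):
        return torch.mean(torch.abs(F.softmax(logits_a, dim=1) -
                                    F.softmax(logits_b, dim=1)))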
def main():
    root = 'data/'
    source1 = 'real'
    source2 = 'infograph'
    source3 = 'quickdraw'
    target = 'sketch'
    adaptive_weight = True

    # Target dataset: labeled validation split, or the unlabeled test set for 'real'
    if not target == 'real':
        dataset_t = DA(dir=root, name=target, img_size=(224, 224), train=False)
    else:
        dataset_t = test_dataset(dir='data/test', img_size=(224, 224))
    dataloader_t = DataLoader(dataset_t, batch_size=64, shuffle=False, num_workers=8)

    # Checkpoint to evaluate; you may change the path, e.g. 'checkpoints/sketch-30.pth'
    path = 'checkpoints/infograph-0-deming.pth'

    # Networks
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()

    state = torch.load(path)
    print(len(state))
    print(state.keys())
    print()
    feature_extractor.load_state_dict(state['feature_extractor'])
    classifier_1.load_state_dict(state['{}_classifier'.format(source1)])
    classifier_2.load_state_dict(state['{}_classifier'.format(source2)])
    classifier_3.load_state_dict(state['{}_classifier'.format(source3)])

    # Ensemble weights: either the adaptive weights stored in the checkpoint or uniform weights
    if adaptive_weight:
        w1_mean = state['{}_weight'.format(source1)]
        w2_mean = state['{}_weight'.format(source2)]
        w3_mean = state['{}_weight'.format(source3)]
    else:
        w1_mean = 1 / 3
        w2_mean = 1 / 3
        w3_mean = 1 / 3

    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda()
        classifier_1 = classifier_1.cuda()
        classifier_2 = classifier_2.cuda()
        classifier_3 = classifier_3.cuda()

    feature_extractor.eval()
    classifier_1.eval(), classifier_2.eval(), classifier_3.eval()

    ans = open('{}_pred.csv'.format(target), 'w')
    ans.write('image_name,label\n')
    m = nn.Softmax(1)

    with torch.no_grad():
        for idx, (img, name) in enumerate(dataloader_t):
            if torch.cuda.is_available():
                img = img.cuda()
            ft_t = feature_extractor(img)
            pred1 = classifier_1(ft_t)
            pred2 = classifier_2(ft_t)
            pred3 = classifier_3(ft_t)

            # Weighted ensemble of the three classifiers
            pred = pred1 * w1_mean + pred2 * w2_mean + pred3 * w3_mean
            pred = m(pred)
            # embed()
            _, pred = torch.max(pred, dim=1)
            print('\rPredicting... Progress: %.1f %%' % (100 * (idx + 1) / len(dataloader_t)), end='')

            for i in range(len(name)):
                # Write the predicted class index as a plain integer (pred[i] is a 0-dim tensor)
                ans.write('{},{}\n'.format(os.path.join('test/', name[i]), pred[i].item()))

    ans.close()
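
# --------------------------------------------------------------------------
# A standard entry-point guard, presumably present in (or easily added to) each
# of these scripts, so the prediction script can be run directly from the
# command line.
if __name__ == '__main__':
    main()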