def load_data():
    """Build the training loaders for the source (MNIST) and target domains.

    Reads the module-level settings ``IMAGE_SIZE``, ``BATCH_SIZE``,
    ``IMG_DIR_SRC`` and ``IMG_DIR_TAR``.

    Returns:
        tuple: ``(dataloader_source, dataloader_target)``.  Both loaders
        shuffle their data and drop the last incomplete batch.
    """
    # MNIST tensors have a single channel.  Recent torchvision releases raise
    # a shape-mismatch error when three-channel statistics are applied to a
    # one-channel tensor, while older releases only ever used the first entry.
    # A one-element mean/std therefore yields the same pixel values on every
    # version and no longer crashes.
    transform_source = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    # presumably GetLoader yields 3-channel (RGB) images -- TODO confirm
    transform_target = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    dataset_source = datasets.MNIST(
        root=IMG_DIR_SRC,
        train=True,
        transform=transform_source,
        download=True,
    )
    dataloader_source = torch.utils.data.DataLoader(
        dataset=dataset_source,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    train_list = IMG_DIR_TAR + '/mnist_m_train_labels.txt'
    dataset_target = GetLoader(
        data_root=IMG_DIR_TAR + '/mnist_m_train',
        data_list=train_list,
        transform=transform_target,
    )
    dataloader_target = torch.utils.data.DataLoader(
        dataset=dataset_target,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )
    return dataloader_source, dataloader_target
def load_test_data(dataset_name):
    """Build the (unshuffled) test loader for one of the two domains.

    Reads the module-level settings ``IMAGE_SIZE``, ``BATCH_SIZE`` and
    ``IMG_DIR_SRC``.

    Args:
        dataset_name: ``'mnist_m'`` selects the MNIST-M test split read
            through ``GetLoader``; any other value selects the MNIST test
            split.

    Returns:
        A ``torch.utils.data.DataLoader`` over the chosen test set.
    """
    if dataset_name == 'mnist_m':
        # presumably GetLoader yields 3-channel (RGB) images -- TODO confirm
        img_transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        test_list = '../dataset/mnist_m/mnist_m_test_labels.txt'
        dataset = GetLoader(
            data_root='../dataset/mnist_m/mnist_m_test',
            data_list=test_list,
            transform=img_transform,
        )
    else:
        # MNIST tensors have a single channel.  Recent torchvision releases
        # raise a shape-mismatch error for three-channel statistics on a
        # one-channel tensor; older releases only used the first entry, so a
        # one-element mean/std gives identical values without the crash.
        img_transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        dataset = datasets.MNIST(
            root=IMG_DIR_SRC,
            train=False,
            transform=img_transform,
            download=True,
        )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
    )
    return dataloader
def load_data():
    """Return the training dataset selected by the module-level ``DATASET_NAME``.

    Supported names are ``'cifar'``, ``'gtsrb'``, ``'mnist'`` and
    ``'mnistm'``.  Every pipeline resizes to 32 pixels, converts to a tensor
    and normalizes.

    Returns:
        The constructed training dataset object.

    Raises:
        SystemExit: with status 1 when ``DATASET_NAME`` is not supported.
    """
    if DATASET_NAME == 'cifar':
        img_transform_cifar = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        return datasets.CIFAR10(root='CIFAR',
                                train=True,
                                transform=img_transform_cifar,
                                target_transform=None,
                                download=True)

    if DATASET_NAME == 'gtsrb':
        img_transform_gtsrb = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.3403, 0.3121, 0.3214),
                                 (0.2724, 0.2608, 0.2669)),
        ])
        return gtsrb_dataset.GTSRB(root_dir='./',
                                   train=True,
                                   transform=img_transform_gtsrb)

    if DATASET_NAME in ('mnist', 'mnistm'):
        # Both MNIST variants share the same scalar normalization.
        img_transform_mnist = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])
        if DATASET_NAME == 'mnist':
            return datasets.MNIST(root='./',
                                  train=True,
                                  transform=img_transform_mnist,
                                  download=True)
        train_list = './mnist_m/mnist_m_train_labels.txt'
        return GetLoader(data_root='./mnist_m/mnist_m_train',
                         data_list=train_list,
                         transform=img_transform_mnist)

    # The interactive ``exit()`` helper is only installed by the ``site``
    # module and terminates with status 0; raise SystemExit directly so the
    # failure is always available and visible to the calling shell.
    print('Data not found.')
    raise SystemExit(1)
]) dataset_source = datasets.MNIST(root='dataset', train=True, transform=img_transform_source, download=True) dataloader_source = torch.utils.data.DataLoader(dataset=dataset_source, batch_size=batch_size, shuffle=True, num_workers=8) train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt') dataset_target = GetLoader(data_root=os.path.join(target_image_root, 'mnist_m_train'), data_list=train_list, transform=img_transform_target) dataloader_target = torch.utils.data.DataLoader(dataset=dataset_target, batch_size=batch_size, shuffle=True, num_workers=8) # load model my_net = CNNModel() # setup optimizer optimizer = optim.Adam(my_net.parameters(), lr=lr)
transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) ]) img_transform_target = transforms.Compose([ transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) if dataset_name == 'mnist_m': test_list = os.path.join(image_root, 'mnist_m_test_labels.txt') dataset = GetLoader(data_root=os.path.join(image_root, 'mnist_m_test'), data_list=test_list, transform=img_transform_target) else: dataset = datasets.MNIST( root='dataset', train=False, transform=img_transform_source, ) dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=8) """ test """ my_net = torch.load(
def test(epoch, name):
    """Evaluate a saved DSN checkpoint on the MNIST or MNIST-M test split.

    Loads ``model/dsn_mnist_mnistm_epoch_<epoch>.pth``, classifies every test
    batch through the network's ``'source'`` / ``'share'`` path, saves the
    reconstructions of the second-to-last batch as PNG files in the working
    directory and prints the accuracy.

    Args:
        epoch: epoch index used to build the checkpoint file name.
        name: ``'mnist'`` (evaluated in ``'source'`` mode) or ``'mnist_m'``
            (evaluated in ``'target'`` mode).

    Raises:
        ValueError: if ``name`` is not one of the two supported datasets.
    """
    ###################
    # params          #
    ###################
    cuda = True
    cudnn.benchmark = True
    batch_size = 64
    image_size = 28

    ###################
    # load data       #
    ###################
    model_root = 'model'

    if name == 'mnist':
        mode = 'source'
        # MNIST tensors have a single channel.  Recent torchvision releases
        # reject three-channel statistics on a one-channel tensor; older ones
        # only used the first entry, so (0.5,) gives the same values.
        img_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        image_root = os.path.join('dataset', 'mnist')
        dataset = datasets.MNIST(root=image_root,
                                 train=False,
                                 transform=img_transform)
    elif name == 'mnist_m':
        mode = 'target'
        img_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        image_root = os.path.join('dataset', 'mnist_m', 'mnist_m_test')
        test_list = os.path.join('dataset', 'mnist_m',
                                 'mnist_m_test_labels.txt')
        dataset = GetLoader(data_root=image_root,
                            data_list=test_list,
                            transform=img_transform)
    else:
        # Previously this only printed a message and then failed further
        # down on the unbound ``dataloader``; fail immediately instead.
        raise ValueError('error dataset name')

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)

    ####################
    # load model       #
    ####################
    my_net = DSN()
    checkpoint = torch.load(
        os.path.join(model_root,
                     'dsn_mnist_mnistm_epoch_' + str(epoch) + '.pth'))
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    ####################
    # transform image  #
    ####################
    def tr_image(img):
        # Affine map (x + 1) / 2, i.e. [-1, 1] -> [0, 1] for saving.
        img_new = (img + 1) / 2
        return img_new

    len_dataloader = len(dataloader)
    n_total = 0
    n_correct = 0

    # Iterating the loader directly replaces the ``iterator.next()`` calls,
    # which only exist on Python 2 style iterators.
    for i, (img, label) in enumerate(dataloader):
        batch_size = len(label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # resize_as_(input_img) is a no-op that keeps the buffer at three
        # channels; copy_ then fills it from the batch (NOTE(review): this
        # appears to rely on copy_ broadcasting 1-channel MNIST batches).
        input_img.resize_as_(input_img).copy_(img)
        class_label.resize_as_(label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        # Classification always goes through the source/shared path.
        result = my_net(input_data=inputv_img, mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='all')
        rec_img_all = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='share')
        rec_img_share = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode,
                        rec_scheme='private')
        rec_img_private = tr_image(result[-1].data)

        if i == len_dataloader - 2:
            vutils.save_image(rec_img_all, name + '_rec_image_all.png',
                              nrow=8)
            vutils.save_image(rec_img_share, name + '_rec_image_share.png',
                              nrow=8)
            vutils.save_image(rec_img_private,
                              name + '_rec_image_private.png', nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

    accu = n_correct * 1.0 / n_total

    # Single-argument print(...) is valid on both Python 2 and Python 3.
    print('epoch: %d, accuracy of the %s dataset: %f' % (epoch, name, accu))
def test(dataset_name):
    """Return the classification accuracy of the current model on a test set.

    Loads the pickled model ``models/mnist_mnistm_model_epoch_current.pth``
    and evaluates it with ``alpha = 0``.

    Args:
        dataset_name: ``'MNIST'`` or ``'mnist_m'``.

    Returns:
        Accuracy computed as ``n_correct / n_total`` over all test batches.

    Raises:
        AssertionError: if ``dataset_name`` is not a supported name.
    """
    assert dataset_name in ['MNIST', 'mnist_m']

    model_root = 'models'
    image_root = os.path.join('dataset', dataset_name)

    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    image_size = 28
    alpha = 0

    # ---- load data ----
    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    if dataset_name == 'mnist_m':
        test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')
        dataset = GetLoader(data_root=os.path.join(image_root,
                                                   'mnist_m_test'),
                            data_list=test_list,
                            transform=img_transform_target)
    else:
        dataset = datasets.MNIST(
            root='dataset',
            train=False,
            transform=img_transform_source,
        )

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)

    # ---- test ----
    # NOTE: torch.load unpickles a whole model object; only use it on
    # checkpoint files from a trusted source.
    my_net = torch.load(
        os.path.join(model_root, 'mnist_mnistm_model_epoch_current.pth'))
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    n_total = 0
    n_correct = 0

    # Iterate the loader directly (DataLoader iterators no longer expose a
    # ``.next()`` method) and skip autograd bookkeeping during evaluation.
    with torch.no_grad():
        for t_img, t_label in dataloader:
            batch_size = len(t_label)

            if cuda:
                t_img = t_img.cuda()
                t_label = t_label.cuda()

            class_output, _ = my_net(input_data=t_img, alpha=alpha)
            pred = class_output.data.max(1, keepdim=True)[1]
            n_correct += pred.eq(t_label.data.view_as(pred)).cpu().sum()
            n_total += batch_size

    accu = n_correct.data.numpy() * 1.0 / n_total

    return accu
def run(net_str):
    """Fine-tune a saved DSN checkpoint on dataset1 (source) and dataset2 (target).

    The checkpoint ``net_str`` is loaded from the hard-coded checkpoint
    directory, trained for ``n_epoch`` epochs with alternating target and
    source updates, evaluated with ``test()`` and saved under the first
    unused ``dsn_epoch_<index>.pth`` name in the same directory.  Standard
    output is redirected to ``<model_root>/train.txt`` through ``Logger``.

    Args:
        net_str: file name of the checkpoint (a ``state_dict``) to start from,
            relative to the checkpoint directory.

    Returns:
        The output of the last source-domain forward pass.
        NOTE(review): a commented-out line in the original suggests the
        two average accuracies were meant to be returned instead; the
        current return value is kept for compatibility -- confirm.
    """
    # Raw strings keep the Windows backslashes literal without relying on
    # invalid escape sequences such as "\s" or "\g".
    checkpoint_dir = (r'D:\study\graduation_project\grdaution_project'
                      r'\instru_identify\dataset18dataset2')

    # Source domain and current target domain.
    net_str = os.path.join(checkpoint_dir, net_str)
    source_image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                     'grdaution_project', 'instru_identify',
                                     'dataset', 'dataset1')
    target_image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                     'grdaution_project', 'instru_identify',
                                     'dataset', 'dataset2')

    # Proportion of historical data that was selected (part of the names).
    history_ratio = str(8)

    # Directory for the training log.
    model_root = 'dataset1' + history_ratio + 'dataset2'
    if not os.path.exists(model_root):
        os.makedirs(model_root)

    # Redirect everything printed from here on into the training log.
    log_path = os.path.join(model_root, 'train.txt')
    sys.stdout = Logger(log_path)

    # Training hyper-parameters.
    cuda = False
    cudnn.benchmark = True
    lr = 1e-2
    batch_size = 16
    image_size = 28
    n_epoch = 1
    step_decay_weight = 0.95
    lr_decay_step = 20000
    active_domain_loss_step = 10000
    weight_decay = 1e-6
    alpha_weight = 0.01
    beta_weight = 0.075
    gamma_weight = 0.25
    momentum = 0.9

    manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    #######################
    # load data           #
    #######################
    # Both domains are loaded with the same three-channel normalization.
    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # Source-domain data.
    source_list = os.path.join(source_image_root, 'dataset1_train_labels.txt')
    dataset_source = GetLoader(
        data_root=os.path.join(source_image_root, 'dataset1_train'),
        data_list=source_list,
        transform=img_transform_target,
    )
    dataloader_source = torch.utils.data.DataLoader(
        dataset=dataset_source,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # load in the main process
    )

    # Target-domain data.
    target_list = os.path.join(target_image_root, 'dataset2_train_labels.txt')
    dataset_target = GetLoader(
        data_root=os.path.join(target_image_root, 'dataset2_train'),
        data_list=target_list,
        transform=img_transform_target,
    )
    dataloader_target = torch.utils.data.DataLoader(
        dataset=dataset_target,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # load in the main process
    )

    #####################
    # load model        #
    #####################
    my_net = DSN()
    my_net.load_state_dict(torch.load(net_str))

    #####################
    # setup optimizer   #
    #####################
    def exp_lr_scheduler(optimizer, step, init_lr=lr,
                         lr_decay_step=lr_decay_step,
                         step_decay_weight=step_decay_weight):
        # Decay the learning rate by step_decay_weight every lr_decay_step.
        # Currently not applied: the training loop below never calls it.
        current_lr = init_lr * (step_decay_weight**(step / lr_decay_step))

        if step % lr_decay_step == 0:
            print('learning rate is set to %f' % current_lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        return optimizer

    optimizer = optim.SGD(my_net.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    # Loss functions.  The classification loss used to be bound to the
    # misspelled name ``loss_classfication`` while the CUDA branch referenced
    # ``loss_classification``, which raised NameError as soon as cuda was on.
    loss_classification = torch.nn.CrossEntropyLoss()
    loss_recon1 = MSE()
    loss_recon2 = SIMSE()
    loss_diff = DiffLoss_tfTrans()
    loss_similarity = torch.nn.CrossEntropyLoss()

    if cuda:
        my_net = my_net.cuda()
        loss_classification = loss_classification.cuda()
        loss_recon1 = loss_recon1.cuda()
        loss_recon2 = loss_recon2.cuda()
        loss_diff = loss_diff.cuda()
        loss_similarity = loss_similarity.cuda()

    for param in my_net.parameters():
        param.requires_grad = True

    #############################
    # training network          #
    #############################
    # Only iterate as far as the shorter loader so neither runs dry.
    len_dataloader = min(len(dataloader_source), len(dataloader_target))
    dann_epoch = np.floor(active_domain_loss_step / len_dataloader * 1.0)

    current_step = 0
    accu_total1 = 0  # sum of the dataset1 accuracies
    accu_total2 = 0  # sum of the dataset2 accuracies
    time_total1 = 0  # total time spent evaluating dataset1
    time_total2 = 0  # total time spent evaluating dataset2

    for epoch in range(n_epoch):
        data_source_iter = iter(dataloader_source)
        data_target_iter = iter(dataloader_target)

        for i in range(len_dataloader):
            ########################
            # target data training #
            ########################
            # next(it) replaces it.next(), which current iterators lack.
            t_img, t_label = next(data_target_iter)

            my_net.zero_grad()
            loss = 0
            batch_size = len(t_label)

            input_img = torch.FloatTensor(batch_size, 3, image_size,
                                          image_size)
            domain_label = torch.ones(batch_size).long()

            if cuda:
                t_img = t_img.cuda()
                input_img = input_img.cuda()
                domain_label = domain_label.cuda()

            input_img.resize_as_(t_img).copy_(t_img)
            target_inputv_img = Variable(input_img)
            target_domainv_label = Variable(domain_label)

            if current_step > active_domain_loss_step:
                # Domain-adaptation factor, ramped through a sigmoid.
                p = float(i + (epoch - dann_epoch) * len_dataloader /
                          (n_epoch - dann_epoch) / len_dataloader)
                p = 2. / (1. + np.exp(-10 * p)) - 1

                # Domain loss is active: also train the domain classifier.
                result = my_net(input_data=target_inputv_img, mode='target',
                                rec_scheme='all', p=p)
                (target_private_code, target_share_code,
                 target_domain_label, target_rec_code) = result
                target_dann = gamma_weight * loss_similarity(
                    target_domain_label, target_domainv_label)
                loss += target_dann
            else:
                result = my_net(input_data=target_inputv_img, mode='target',
                                rec_scheme='all')
                (target_private_code, target_share_code, _,
                 target_rec_code) = result

            target_diff = beta_weight * loss_diff(
                target_private_code, target_share_code, weight=0.05)
            loss += target_diff
            target_mse = alpha_weight * loss_recon1(target_rec_code,
                                                    target_inputv_img)
            loss += target_mse
            target_simse = alpha_weight * loss_recon2(target_rec_code,
                                                      target_inputv_img)
            # Fixed: the MSE term used to be added a second time here and
            # the SIMSE term was never part of the loss.
            loss += target_simse

            loss.backward()
            optimizer.step()

            ########################
            # source data training #
            ########################
            s_img, s_label = next(data_source_iter)

            my_net.zero_grad()
            batch_size = len(s_label)

            input_img = torch.FloatTensor(batch_size, 3, image_size,
                                          image_size)
            class_label = torch.LongTensor(batch_size)
            # Fixed: the long() conversion used to be assigned to a
            # misspelled name, leaving the domain labels as floats.
            domain_label = torch.zeros(batch_size).long()

            loss = 0

            if cuda:
                s_img = s_img.cuda()
                s_label = s_label.cuda()
                input_img = input_img.cuda()
                class_label = class_label.cuda()
                domain_label = domain_label.cuda()

            # The buffer keeps its own (batch, 3, H, W) shape here, unlike
            # the target branch, before the batch is copied in.
            input_img.resize_as_(input_img).copy_(s_img)
            class_label.resize_as_(s_label).copy_(s_label)
            source_inputv_img = Variable(input_img)
            source_classv_label = Variable(class_label)
            source_domainv_label = Variable(domain_label)

            if current_step > active_domain_loss_step:
                # ``p`` was computed in the target branch of this step.
                result = my_net(input_data=source_inputv_img, mode='source',
                                rec_scheme='all', p=p)
                # Fixed: the class prediction used to overwrite the ground
                # truth labels, and the domain loss was computed against it
                # instead of the source domain labels.
                (source_private_code, source_share_code,
                 source_domain_label, source_class_label,
                 source_rec_code) = result
                source_dann = gamma_weight * loss_similarity(
                    source_domain_label, source_domainv_label)
                loss += source_dann
            else:
                result = my_net(input_data=source_inputv_img, mode='source',
                                rec_scheme='all')
                (source_private_code, source_share_code, _,
                 source_class_label, source_rec_code) = result

            source_classification = loss_classification(
                source_class_label, source_classv_label)
            loss += source_classification

            source_diff = beta_weight * loss_diff(
                source_private_code, source_share_code, weight=0.05)
            loss += source_diff
            source_mse = alpha_weight * loss_recon1(source_rec_code,
                                                    source_inputv_img)
            loss += source_mse
            # NOTE(review): the target branch weights SIMSE with
            # alpha_weight; gamma_weight here may be unintended -- confirm.
            source_simse = gamma_weight * loss_recon2(source_rec_code,
                                                      source_inputv_img)
            loss += source_simse

            loss.backward()
            optimizer.step()

            current_step += 1

            ##############
            # evaluation #
            ##############
            # NOTE(review): the averages below divide by
            # len_dataloader * n_epoch, which matches one evaluation per
            # training step, so the evaluation is kept inside the step loop
            # -- confirm against the original indentation.
            start1 = time.time()
            accu1 = test(epoch=epoch, name='dataset1')
            end1 = time.time()
            time_total1 += end1 - start1
            accu_total1 += accu1

            start2 = time.time()
            accu2 = test(epoch=epoch, name='dataset2')
            end2 = time.time()
            time_total2 += end2 - start2
            accu_total2 += accu2

        # Save under the first checkpoint index that is not taken yet.
        model_index = epoch
        model_path = (checkpoint_dir + r'\dsn_epoch_' + str(model_index) +
                      '.pth')
        while os.path.exists(model_path):
            model_index = model_index + 1
            model_path = (checkpoint_dir + r'\dsn_epoch_' +
                          str(model_index) + '.pth')
        torch.save(my_net.state_dict(), model_path)

    average_accu1 = accu_total1 / (len_dataloader * n_epoch)
    average_accu2 = accu_total2 / (len_dataloader * n_epoch)

    # All figures are reported with three decimals.
    print(round(float(average_accu1), 3))
    print(round(float(average_accu2), 3))
    print(round(float(time_total1), 3))
    print(round(float(time_total2), 3))

    return result
def test(epoch, name):
    """Evaluate the DSN checkpoint of ``epoch`` on the dataset1/dataset2 test set.

    Loads ``dataset18dataset2/dsn_epoch_<epoch>.pth``, classifies every test
    batch through the network's ``'source'`` / ``'share'`` path and stores
    the reconstructions of the second-to-last batch under
    ``dataset18dataset2/images``.

    Args:
        epoch: epoch index used to build the checkpoint file name.
        name: ``'dataset1'`` (evaluated in ``'source'`` mode) or
            ``'dataset2'`` (evaluated in ``'target'`` mode).

    Returns:
        Accuracy computed as ``n_correct * 1.0 / n_total`` over all batches.

    Raises:
        ValueError: if ``name`` is not one of the two supported datasets.
    """
    cuda = False
    cudnn.benchmark = True
    batch_size = 16
    image_size = 28
    p = str(8)
    model_root = 'dataset1' + p + 'dataset2'

    ################
    # load data    #
    ################
    # NOTE(review): run() trains dataset1 with the three-channel 0.5/0.5
    # normalization, while the test set below uses single-channel MNIST
    # statistics -- confirm which one is intended.
    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    if name == 'dataset1':
        mode = 'source'
        image_root = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset\dataset1\dataset1_test'
        test_list = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset\dataset1\dataset1_test_labels.txt'
        dataset = GetLoader(
            data_root=image_root,
            data_list=test_list,
            transform=img_transform_source,
        )
    elif name == 'dataset2':
        mode = 'target'
        image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                  'grdaution_project', 'instru_identify',
                                  'dataset', 'dataset2', 'dataset2_test')
        test_list = os.path.join('D:\\', 'study', 'graduation_project',
                                 'grdaution_project', 'instru_identify',
                                 'dataset', 'dataset2',
                                 'dataset2_test_labels.txt')
        dataset = GetLoader(
            data_root=image_root,
            data_list=test_list,
            transform=img_transform_target,
        )
    else:
        # Previously this only printed a message and then failed further
        # down on the unbound ``dataloader``; fail immediately instead.
        raise ValueError('error dataset name')

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=0)

    ###############
    # load model  #
    ###############
    my_net = DSN()
    checkpoint = torch.load(
        os.path.join(model_root, 'dsn_epoch_' + str(epoch) + '.pth'))
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    ###################
    # transform image #
    ###################
    def tr_image(img):
        # Affine map (x + 1) / 2, i.e. [-1, 1] -> [0, 1] for saving.
        img_new = (img + 1) / 2
        return img_new

    len_dataloader = len(dataloader)
    n_total = 0
    n_correct = 0

    # Every batch is evaluated.  The former ``while i < len_dataloader - 1``
    # loop silently dropped the last batch and divided by zero when the
    # loader held a single batch.
    for i, (img, label) in enumerate(dataloader):
        batch_size = len(label)  # number of images in this batch

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # Both buffers keep their pre-allocated shape; copy_ fills them.
        input_img.resize_as_(input_img).copy_(img)
        class_label.resize_as_(class_label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        # Classification always goes through the source/shared path.
        result = my_net(input_data=inputv_img, mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='all')
        rec_img_all = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='share')
        rec_img_share = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode,
                        rec_scheme='private')
        rec_img_private = tr_image(result[-1].data)

        if i == len_dataloader - 2:
            image_save_path = os.path.join(model_root, 'images')
            if not os.path.exists(image_save_path):
                os.mkdir(image_save_path)
            vutils.save_image(rec_img_all,
                              image_save_path + '/' + name +
                              '_rec_image_all.png',
                              nrow=8)
            vutils.save_image(rec_img_share,
                              image_save_path + '/' + name +
                              'rec_image_share.png',
                              nrow=8)
            vutils.save_image(rec_img_private,
                              image_save_path + '/' + name +
                              'rec_image_private.png',
                              nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

    accu = n_correct * 1.0 / n_total
    return accu
transforms.Resize((28,28)), transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) # dataset_source = datasets.MNIST( # root='dataset', # train=True, # transform=img_transform_source, # download=True # ) source_list = os.path.join(source_image_root, 'image_label.txt') dataset_source = GetLoader( data_root='/root/Data/source/', data_list=source_list, transform=img_transform_source ) dataloader_source = torch.utils.data.DataLoader( dataset=dataset_source, batch_size=batch_size, shuffle=True, num_workers=4) train_list = os.path.join(target_image_root, 'image_label.txt') dataset_target = GetLoader( data_root='/root/Data/target/', data_list=train_list, transform=img_transform_target, )
def run(args):
    """Set up data, models and optimizers and start target-CNN training.

    Builds the source/target datasets for ``args.mode``, restores the source
    CNN from ``<args.trained><mode>/best_model.pt`` when that file exists,
    initializes the target CNN from the source weights and hands everything
    to ``train_target_cnn``.

    Args:
        args: namespace providing at least ``logdir``, ``trained``, ``mode``
            (``'m2mm'`` or ``'s2u'``), ``in_channels``, ``device`` and
            ``lr``.  ``logdir`` and ``trained`` are extended in place with
            the mode; the namespace is also forwarded to ``Discriminator``
            and ``train_target_cnn``.

    Raises:
        ValueError: if ``args.mode`` is not a supported mode.
    """
    args.logdir = args.logdir + args.mode
    args.trained = args.trained + args.mode + '/best_model.pt'
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data
    batch_size = 128
    if args.mode == 'm2mm':
        target_image_root = os.path.join('dataset', 'mnist_m')
        image_size = 28
        img_transform_source = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ])
        img_transform_target = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        dataset_source = datasets.MNIST(root='dataset',
                                        train=True,
                                        transform=img_transform_source,
                                        download=True)
        train_list = os.path.join(target_image_root,
                                  'mnist_m_train_labels.txt')
        dataset_target_train = GetLoader(
            data_root=os.path.join(target_image_root, 'mnist_m_train'),
            data_list=train_list,
            transform=img_transform_target)
        test_list = os.path.join(target_image_root,
                                 'mnist_m_test_labels.txt')
        dataset_target_test = GetLoader(
            data_root=os.path.join(target_image_root, 'mnist_m_test'),
            data_list=test_list,
            transform=img_transform_target)
    elif args.mode == 's2u':
        dataset_source = svhn.SVHN(
            './data/svhn/',
            split='train',
            download=True,
            transform=transforms.Compose([
                transforms.Resize(28),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        dataset_target_train = usps.USPS(
            './data/usps/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]))
        dataset_target_test = usps.USPS(
            './data/usps/',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]))
    else:
        # Without this guard an unknown mode only surfaced below as an
        # unbound-variable error on ``dataset_source``.
        raise ValueError('unsupported mode: {!r}'.format(args.mode))

    source_train_loader = torch.utils.data.DataLoader(
        dataset=dataset_source,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8)
    target_train_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8)
    target_test_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8)

    # train source CNN
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        print("load model")
        # map_location lets a checkpoint saved on another device (e.g. a
        # GPU) be restored on the device selected for this run.
        c = torch.load(args.trained, map_location=args.device)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))
    else:
        print("not load model")

    # train target CNN
    target_cnn = CNN(in_channels=args.in_channels,
                     target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    # Only the target encoder and the discriminator are optimized here.
    optimizer = optim.Adam(target_cnn.encoder.parameters(), lr=args.lr)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr)
    train_target_cnn(source_cnn,
                     target_cnn,
                     discriminator,
                     criterion,
                     optimizer,
                     d_optimizer,
                     source_train_loader,
                     target_train_loader,
                     target_test_loader,
                     args=args)
def test(dataset_name, epoch):
    """Evaluate the model saved for ``epoch`` and write accuracy artefacts.

    Loads ``models/mnist_mnistm_model_epoch_<epoch>.pth``, evaluates it on
    ``/root/Data/<dataset_name>``, prints overall and per-class accuracy,
    saves a t-SNE scatter plot of (at most) the first 1000 class outputs to
    ``output1.png`` and stores the accuracy with ``torch.save`` under
    ``/root/Data/dann_result``.

    Args:
        dataset_name: ``'source'`` or ``'target'``.
        epoch: epoch index used in the checkpoint and result file names.

    Raises:
        AssertionError: if ``dataset_name`` is not a supported name.
    """
    assert dataset_name in ['source', 'target']

    model_root = 'models'
    image_root = os.path.join('/root/Data', dataset_name)

    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    image_size = 28
    alpha = 0

    # ---- load data ----
    img_transform_target = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    target_list = os.path.join(image_root, 'image_label.txt')
    dataset = GetLoader(data_root=image_root,
                        data_list=target_list,
                        transform=img_transform_target)

    # Shuffling is kept: the t-SNE plot below only uses the first
    # ``plot_only`` samples that come out of the loader.
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=2)

    # ---- test ----
    # NOTE: torch.load unpickles a whole model object; only use it on
    # checkpoint files from a trusted source.
    my_net = torch.load(
        os.path.join(model_root,
                     'mnist_mnistm_model_epoch_' + str(epoch) + '.pth'))
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    n_total = 0
    n_correct = 0
    num_class = 15
    acc_class = [0 for _ in range(num_class)]
    count_class = [0 for _ in range(num_class)]
    # Per-batch arrays are collected and concatenated once after the loop
    # instead of re-copying the growing result on every batch.
    output_batches = []
    label_batches = []

    # Iterating the loader directly replaces the ``iterator.next()`` calls,
    # which current DataLoader iterators do not provide.
    for t_img, t_label in dataloader:
        batch_size = len(t_label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        input_img.resize_as_(t_img).copy_(t_img)
        class_label.resize_as_(t_label).copy_(t_label)

        class_output, _ = my_net(input_data=input_img, alpha=alpha)

        pred = class_output.data.max(1, keepdim=True)[1]
        pred1 = class_output.data.max(1)[1]
        n_correct += pred.eq(class_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        # Per-class bookkeeping.
        hits = pred1.eq(t_label.data)
        for label_index, is_hit in zip(t_label.tolist(), hits.tolist()):
            count_class[label_index] += 1
            if is_hit:
                acc_class[label_index] += 1

        output_batches.append(class_output.cpu().data.numpy())
        label_batches.append(t_label.cpu().numpy())

    tsne_results = np.concatenate(output_batches)
    tsne_labels = np.concatenate(label_batches)

    plot_only = 1000
    tsne_model = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    tsne_transformed = tsne_model.fit_transform(tsne_results[:plot_only, :])
    tsne_labels = tsne_labels[:plot_only]

    # Draw on a fresh figure and close it afterwards; plotting on the
    # implicit global figure made points pile up across calls.  All points
    # are drawn in one call with one RGBA colour per label.
    fig = plt.figure()
    colors = cm.rainbow((255 * tsne_labels / num_class).astype(int))
    plt.scatter(tsne_transformed[:, 0], tsne_transformed[:, 1], c=colors)
    plt.xticks([])
    plt.yticks([])
    plt.savefig('output1.png')
    plt.close(fig)

    for print_index, (hit_count, sample_count) in enumerate(
            zip(acc_class, count_class)):
        if sample_count == 0:
            # A class without samples used to raise ZeroDivisionError.
            print('Class:{}, Accuracy:n/a (no samples)'.format(print_index))
            continue
        print('Class:{}, Accuracy:{:.2f}%'.format(
            print_index, 100. * hit_count / sample_count))

    accu = n_correct.data.numpy() * 1.0 / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' %
          (epoch, dataset_name, accu))
    torch.save(
        accu, '/root/Data/dann_result/dann_ep_' + str(epoch) + '_' +
        dataset_name + '_' + str(accu) + '.pt')