def build_models(self):
    ################### generator #########################################
    self.G = SingleGenerator(input_nc=self.opt.input_nc, output_nc=self.opt.input_nc,
                             ngf=self.opt.ngf, nc=self.opt.c_num + self.opt.d_num,
                             e_blocks=self.opt.e_blocks, norm_type=self.opt.norm)
    ################### encoder ###########################################
    self.E = None
    if self.opt.mode == 'multimodal':
        self.E = Encoder(input_nc=self.opt.input_nc, output_nc=self.opt.c_num,
                         nef=self.opt.nef, nd=self.opt.d_num, n_blocks=4,
                         norm_type=self.opt.norm)
    if self.opt.isTrain:
        ################### discriminators #####################################
        self.Ds = []
        for i in range(self.opt.d_num):
            self.Ds.append(D_NET_Multi(input_nc=self.opt.output_nc, ndf=self.opt.ndf,
                                       block_num=3, norm_type=self.opt.norm))
        ################### init_weights ########################################
        if self.opt.continue_train:
            self.G.load_state_dict(torch.load('{}/G_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
            if self.E is not None:
                self.E.load_state_dict(torch.load('{}/E_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
            for i in range(self.opt.d_num):
                self.Ds[i].load_state_dict(torch.load('{}/D_{}_{}.pth'.format(self.opt.model_dir, i, self.opt.which_epoch)))
        else:
            self.G.apply(weights_init(self.opt.init_type))
            if self.E is not None:
                self.E.apply(weights_init(self.opt.init_type))
            for i in range(self.opt.d_num):
                self.Ds[i].apply(weights_init(self.opt.init_type))
        ################### use GPU #############################################
        self.G.cuda()
        if self.E is not None:
            self.E.cuda()
        for i in range(self.opt.d_num):
            self.Ds[i].cuda()
        ################### set criterion ########################################
        self.criterionGAN = GANLoss(mse_loss=(self.opt.c_gan_mode == 'lsgan'))
        ################## define optimizers #####################################
        self.define_optimizers()
    else:
        self.G.load_state_dict(torch.load('{}/G_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
        self.G.cuda()
        self.G.eval()
        if self.E is not None:
            self.E.load_state_dict(torch.load('{}/E_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
            self.E.cuda()
            self.E.eval()
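# ----------------------------------------------------------------------
# Note: `weights_init(init_type)` used above is imported from elsewhere in
# this repo. The closure below is a minimal sketch of the usual factory
# pattern, assuming `init_type` is one of 'gaussian'/'normal', 'xavier', or
# 'kaiming'; the exact gains and layer filters of the project's real
# implementation may differ.
import torch.nn as nn

def weights_init(init_type='gaussian'):
    def init_fn(m):
        classname = m.__class__.__name__
        if ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
            if init_type in ('gaussian', 'normal'):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            else:
                raise NotImplementedError('Unsupported init: {}'.format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    return init_fn
# ----------------------------------------------------------------------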
def main(opt):
    opt.digitroot = _DIGIT_ROOT
    if opt.prefix == '':
        opt.prefix = _PREFIX
    if opt.model == '':
        opt.model = _MODEL
    if opt.beta == '':
        opt.beta = _BETA
    if opt.mu == '':
        opt.mu = _MU
    opt.gamma = _GAMMA
    opt.alpha = _ALPHA
    if opt.norm is None:
        opt.norm = _NORM

    modelname = '{0}_{1}_{2:0.1f}_{3:0.1f}'.format(opt.prefix, opt.model, opt.beta, opt.mu)
    modelpath = 'model/' + modelname + '.pth'
    torch.cuda.set_device(opt.gpu)
    device = torch.device('cuda:{0}'.format(opt.gpu))

    now = datetime.now()
    curtime = now.isoformat()
    run_dir = "runs/{0}_{1}_ongoing".format(curtime[0:16], modelname)
    resultname = '{2}/result_{0}_{1}.txt'.format(modelname, opt.num_epochs, run_dir)

    n_ch = 64
    n_hidden = 5
    n_resblock = 4

    prompt = ''
    prompt += '====================================\n'
    prompt += run_dir + '\n'
    for arg in vars(opt):
        prompt = '{0}{1} : {2}\n'.format(prompt, arg, getattr(opt, arg))
    prompt += '====================================\n'
    print(prompt, end='')

    # opt.model = 'svhn_mnist'
    # opt.model = 'mnist_usps'
    # opt.model = 'usps_mnist'
    # opt.model = 'cifar10_stl10'
    # opt.model = 'stl10_cifar10'
    # opt.model = 'svhn_svhn'
    # opt.model = 'mnist_mnist'
    # opt.model = 'usps_usps'
    # opt.model = 'svhn_usps'

    #########################
    #### DATASET
    #########################
    modelsplit = opt.model.split('_')
    if (modelsplit[0] == 'mnist' or modelsplit[0] == 'usps') and modelsplit[1] != 'svhn':
        n_c_in = 1  # number of color channels
    else:
        n_c_in = 3  # number of color channels

    if (modelsplit[1] == 'mnist' or modelsplit[1] == 'usps') and modelsplit[0] != 'svhn':
        n_c_out = 1  # number of color channels
    else:
        n_c_out = 3  # number of color channels

    trainset, trainset2, testset = utils.load_data(opt=opt)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size,
                                               drop_last=True,
                                               sampler=InfiniteSampler(len(trainset)))
    train_loader2 = torch.utils.data.DataLoader(trainset2, batch_size=opt.batch_size,
                                                drop_last=True,
                                                sampler=InfiniteSampler(len(trainset2)))
    test_loader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size,
                                              shuffle=True, drop_last=True)

    n_sample = max(len(trainset), len(trainset2))
    iter_per_epoch = n_sample // opt.batch_size + 1
    src_train_iter = iter(train_loader)
    tgt_train_iter = iter(train_loader2)

    if opt.norm:
        X_min = -1  # Normalize(0.5, 0.5) maps inputs from [0, 1] to [-1, 1]
        X_max = 1
    else:
        X_min = trainset.data.min()
        X_max = trainset.data.max()
    # pdb.set_trace()

    #########################
    #### Model
    #########################
    if modelsplit[0] == 'svhn' or modelsplit[1] == 'svhn' or \
            modelsplit[0] == 'usps' or modelsplit[0] == 'cifar10' or \
            modelsplit[0] == 'stl10':
        model1 = conv9(p=opt.dropout_probability).cuda()  # 3x32x32 -> 1x128x1x1 (before FC)
        model2 = conv9(p=opt.dropout_probability).cuda()  # 3x32x32 -> 1x128x1x1 (before FC)
    else:
        model1 = conv3(p=opt.dropout_probability).cuda()  # 1x28x28 -> 1x128x4x4 (before FC)
        model2 = conv3(p=opt.dropout_probability).cuda()  # 1x28x28 -> 1x128x4x4 (before FC)

    dropout_mask1 = torch.randint(2, (1, 128, 1, 1), dtype=torch.float).cuda()
    # dropout_mask1 = torch.randint(2, (1, 128, 4, 4), dtype=torch.float).cuda()

    weights_init_gaussian = weights_init('gaussian')

    for X, Y in train_loader:
        res_x = X.shape[-1]
        break
    for X, Y in train_loader2:
        res_y = X.shape[-1]
        break

    gen_st = Generator(n_hidden=n_hidden, n_resblock=n_resblock,
                       n_ch=n_ch, res=res_x, n_c_in=n_c_in, n_c_out=n_c_out).cuda()
    gen_ts = Generator(n_hidden=n_hidden, n_resblock=n_resblock,
                       n_ch=n_ch, res=res_y, n_c_in=n_c_out, n_c_out=n_c_in).cuda()
    dis_s = Discriminator(n_ch=n_ch, res=res_x, n_c_in=n_c_in).cuda()
    dis_t = Discriminator(n_ch=n_ch, res=res_y, n_c_in=n_c_out).cuda()

    gen_st.apply(weights_init_gaussian)
    gen_ts.apply(weights_init_gaussian)
    dis_s.apply(weights_init_gaussian)
    dis_t.apply(weights_init_gaussian)

    pool_size = 50
    fake_src_x_pool = ImagePool(pool_size * opt.batch_size)
    fake_tgt_x_pool = ImagePool(pool_size * opt.batch_size)

    #########################
    #### Loss
    #########################
    opt_config = {
        'lr': opt.learning_rate,
        'weight_decay': opt.weight_decay,
        'betas': (0.5, 0.999)
    }
    opt_gen = torch.optim.Adam(
        chain(gen_st.parameters(), gen_ts.parameters(),
              model1.parameters(), model2.parameters()), **opt_config)
    opt_dis = torch.optim.Adam(chain(dis_s.parameters(), dis_t.parameters()),
                               **opt_config)

    loss_CE = torch.nn.CrossEntropyLoss().cuda()
    loss_KLD = torch.nn.KLDivLoss(reduction='batchmean').cuda()
    loss_LS = GANLoss(device, use_lsgan=True)

    #########################
    #### argument print
    #########################
    writer = SummaryWriter(run_dir)
    f = open(resultname, 'w')
    f.write(prompt)
    f.close()

    #########################
    #### Run
    #########################
    if os.path.isfile(opt.pretrained):
        modelpath = opt.pretrained
        print("model load..", modelpath)
        checkpoint = torch.load(modelpath, map_location='cuda:{0}'.format(opt.gpu))
        dropout_mask1 = checkpoint['dropout_mask1']
    else:
        modelpath = 'model/{0}'.format(modelname)
        os.makedirs(modelpath, exist_ok=True)
        print("model train..")
        print(modelname)

    niter = 0
    epoch = 0
    while True:
        model1.train()
        model2.train()
        niter += 1

        src_x, src_y = next(src_train_iter)
        tgt_x, tgt_y = next(tgt_train_iter)
        src_x = src_x.cuda()
        src_y = src_y.cuda()
        tgt_x = tgt_x.cuda()

        fake_tgt_x = gen_st(src_x)
        fake_src_x = gen_ts(tgt_x)
        fake_back_src_x = gen_ts(fake_tgt_x)

        if opt.prefix == 'tranlsation_noCE':
            loss_gen = opt.gamma * loss_LS(dis_s(fake_src_x), True)
            loss_gen += opt.alpha * loss_LS(dis_t(fake_tgt_x), True)
        else:
            loss_gen = opt.beta * loss_CE(model2(fake_tgt_x), src_y)
            loss_gen += opt.mu * loss_CE(model1(src_x), src_y)
            loss_gen += opt.gamma * loss_LS(dis_s(fake_src_x), True)
            loss_gen += opt.alpha * loss_LS(dis_t(fake_tgt_x), True)

        loss_dis_s = opt.gamma * loss_LS(dis_s(fake_src_x_pool.query(fake_src_x)), False)
        loss_dis_s += opt.gamma * loss_LS(dis_s(src_x), True)
        loss_dis_t = opt.alpha * loss_LS(dis_t(fake_tgt_x_pool.query(fake_tgt_x)), False)
        loss_dis_t += opt.alpha * loss_LS(dis_t(tgt_x), True)
        loss_dis = loss_dis_s + loss_dis_t

        for optim, loss in zip([opt_dis, opt_gen], [loss_dis, loss_gen]):
            optim.zero_grad()
            loss.backward(retain_graph=True)
            optim.step()

        if niter % opt.print_delay == 0 and niter > 0:
            with torch.no_grad():
                loss_dis_s1 = opt.gamma * loss_LS(dis_s(fake_src_x_pool.query(fake_src_x)), False)
                loss_dis_s2 = opt.gamma * loss_LS(dis_s(src_x), True)
                loss_dis_t1 = opt.alpha * loss_LS(dis_t(fake_tgt_x_pool.query(fake_tgt_x)), False)
                loss_dis_t2 = opt.alpha * loss_LS(dis_t(tgt_x), True)
                loss_gen_s = opt.gamma * loss_LS(dis_s(fake_src_x), True)
                loss_gen_t = opt.alpha * loss_LS(dis_t(fake_tgt_x), True)
                loss_gen_CE_t = opt.beta * loss_CE(model2(fake_tgt_x), src_y)
                loss_gen_CE_s = opt.mu * loss_CE(model1(src_x), src_y)

                print('epoch {0} ({1}/{2}) '.format(epoch, (niter % iter_per_epoch), iter_per_epoch)
                      + 'dis_s1 {0:02.4f}, dis_s2 {1:02.4f}, '.format(loss_dis_s1.item(), loss_dis_s2.item())
                      + 'dis_t1 {0:02.4f}, dis_t2 {1:02.4f}, '.format(loss_dis_t1.item(), loss_dis_t2.item())
                      + 'loss_gen_s {0:02.4f}, loss_gen_t {1:02.4f} '.format(loss_gen_s.item(), loss_gen_t.item())
                      + 'loss_gen_CE_t {0:02.4f}, loss_gen_CE_s {1:02.4f}'.format(loss_gen_CE_t.item(), loss_gen_CE_s.item()),
                      end='\r')

                writer.add_scalar('dis/src', loss_dis_s.item(), niter)
                writer.add_scalar('dis/src1', loss_dis_s1.item(), niter)
                writer.add_scalar('dis/src2', loss_dis_s2.item(), niter)
                writer.add_scalar('dis/tgt', loss_dis_t.item(), niter)
                writer.add_scalar('dis/tgt1', loss_dis_t1.item(), niter)
                writer.add_scalar('dis/tgt2', loss_dis_t2.item(), niter)
                writer.add_scalar('gen', loss_gen.item(), niter)
                writer.add_scalar('gen/src', loss_gen_s.item(), niter)
                writer.add_scalar('gen/tgt', loss_gen_t.item(), niter)
                writer.add_scalar('CE/tgt', loss_CE(model2(fake_tgt_x), src_y).item(), niter)
                writer.add_scalar('CE/src', loss_CE(model1(src_x), src_y).item(), niter)
                # pdb.set_trace()

                if niter % (opt.print_delay * 10) == 0:
                    data_grid = []
                    for x in [src_x, fake_tgt_x, fake_back_src_x, tgt_x, fake_src_x]:
                        x = x.to(torch.device('cpu'))
                        if x.size(1) == 1:
                            x = x.repeat(1, 3, 1, 1)  # grayscale2rgb
                        data_grid.append(x)
                    grid = make_grid(torch.cat(tuple(data_grid), dim=0),
                                     normalize=True, range=(X_min, X_max),
                                     nrow=opt.batch_size)  # for SVHN?
                    writer.add_image('generated_{0}'.format(opt.prefix), grid, niter)

        if niter % iter_per_epoch == 0 and niter > 0:
            with torch.no_grad():
                epoch = niter // iter_per_epoch
                model1.eval()
                model2.eval()

                avgaccuracy1 = 0
                avgaccuracy2 = 0
                n = 0
                nagree = 0
                for X, Y in test_loader:
                    n += X.size()[0]
                    X_test = X.cuda()
                    Y_test = Y.cuda()

                    prediction1 = model1(X_test)
                    predicted_classes1 = torch.argmax(prediction1, 1)
                    correct_count1 = (predicted_classes1 == Y_test)
                    testaccuracy1 = correct_count1.float().sum()
                    avgaccuracy1 += testaccuracy1

                    prediction2 = model2(X_test)
                    predicted_classes2 = torch.argmax(prediction2, 1)
                    correct_count2 = (predicted_classes2 == Y_test)
                    testaccuracy2 = correct_count2.float().sum()
                    avgaccuracy2 += testaccuracy2

                    # accumulate agreement over the whole test set, per batch
                    agreement = (predicted_classes1 == predicted_classes2)
                    nagree = nagree + agreement.int().sum()

                avgaccuracy1 = (avgaccuracy1 / n) * 100
                avgaccuracy2 = (avgaccuracy2 / n) * 100

                writer.add_scalar('accuracy/tgt', avgaccuracy1, niter)
                writer.add_scalar('accuracy/src', avgaccuracy2, niter)
                writer.add_scalar('agreement', (nagree / n) * 100, niter)

                f = open(resultname, 'a')
                f.write('epoch : {0}\n'.format(epoch))
                f.write('\tloss_gen_s : {0:0.4f}\n'.format(loss_gen_s.item()))
                f.write('\tloss_gen_t : {0:0.4f}\n'.format(loss_gen_t.item()))
                f.write('\tloss_gen_CE_t : {0:0.4f}\n'.format(loss_gen_CE_t.item()))
                f.write('\tloss_gen_CE_s : {0:0.4f}\n'.format(loss_gen_CE_s.item()))
                f.write('\tloss_dis_s1 : {0:0.4f}\n'.format(loss_dis_s1.item()))
                f.write('\tloss_dis_t1 : {0:0.4f}\n'.format(loss_dis_t1.item()))
                f.write('\tloss_dis_s2 : {0:0.4f}\n'.format(loss_dis_s2.item()))
                f.write('\tloss_dis_t2 : {0:0.4f}\n'.format(loss_dis_t2.item()))
                f.write('\tavgaccuracy_tgt : {0:0.2f}\n'.format(avgaccuracy1))
                f.write('\tavgaccuracy_src : {0:0.2f}\n'.format(avgaccuracy2))
                f.write('\tagreement : {0}\n'.format(nagree))
                f.close()

        if epoch >= opt.num_epochs:
            os.rename(run_dir, run_dir[:-8])  # drop the '_ongoing' suffix
            break
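# ----------------------------------------------------------------------
# Note: the `InfiniteSampler` used by the DataLoaders above (which are
# consumed with `next()` and never exhausted) is defined elsewhere in the
# repo. A minimal sketch, assuming it only needs to yield shuffled dataset
# indices in an endless stream.
import numpy as np
from torch.utils.data.sampler import Sampler

class InfiniteSampler(Sampler):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        while True:
            # reshuffle once per full pass over the dataset
            for i in np.random.permutation(self.num_samples):
                yield int(i)

    def __len__(self):
        return 2 ** 31  # effectively unbounded
# ----------------------------------------------------------------------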
def experiment(exp, affine, num_epochs):
    writer = SummaryWriter()
    log_dir = 'log/{:s}/sbada'.format(exp)
    os.makedirs(log_dir, exist_ok=True)
    device = torch.device('cuda')

    config = get_config('config.yaml')
    alpha = float(config['weight']['alpha'])
    beta = float(config['weight']['beta'])
    gamma = float(config['weight']['gamma'])
    mu = float(config['weight']['mu'])
    new = float(config['weight']['new'])
    eta = 0.0
    batch_size = int(config['batch_size'])
    pool_size = int(config['pool_size'])
    lr = float(config['lr'])
    weight_decay = float(config['weight_decay'])

    src, tgt = load_source_target_datasets(exp)
    n_ch_s = src.train_X.shape[1]  # number of color channels
    n_ch_t = tgt.train_X.shape[1]  # number of color channels
    res = src.train_X.shape[-1]  # size of image
    n_classes = src.n_classes

    train_tfs = get_composed_transforms(train=True, hflip=False)
    test_tfs = get_composed_transforms(train=False, hflip=False)
    src_train = DADataset(src.train_X, src.train_y, train_tfs, affine)
    tgt_train = DADataset(tgt.train_X, None, train_tfs, affine)
    tgt_test = DADataset(tgt.test_X, tgt.test_y, test_tfs, affine)
    del src, tgt

    n_sample = max(len(src_train), len(tgt_train))
    iter_per_epoch = n_sample // batch_size + 1

    weights_init_kaiming = weights_init('kaiming')
    weights_init_gaussian = weights_init('gaussian')

    cls_s = LenetClassifier(n_classes, n_ch_s, res).to(device)
    cls_t = LenetClassifier(n_classes, n_ch_t, res).to(device)
    cls_s.apply(weights_init_kaiming)
    cls_t.apply(weights_init_kaiming)

    gen_s_t_params = {'res': res, 'n_c_in': n_ch_s, 'n_c_out': n_ch_t}
    gen_t_s_params = {'res': res, 'n_c_in': n_ch_t, 'n_c_out': n_ch_s}
    gen_s_t = Generator(**{**config['gen_init'], **gen_s_t_params}).to(device)
    gen_t_s = Generator(**{**config['gen_init'], **gen_t_s_params}).to(device)
    gen_s_t.apply(weights_init_gaussian)
    gen_t_s.apply(weights_init_gaussian)

    dis_s_params = {'res': res, 'n_c_in': n_ch_s}
    dis_t_params = {'res': res, 'n_c_in': n_ch_t}
    dis_s = Discriminator(**{**config['dis_init'], **dis_s_params}).to(device)
    dis_t = Discriminator(**{**config['dis_init'], **dis_t_params}).to(device)
    dis_s.apply(weights_init_gaussian)
    dis_t.apply(weights_init_gaussian)

    # use a separate name for the optimizer settings so the yaml config
    # (still needed for config['weight']['eta'] below) is not shadowed
    opt_config = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0.5, 0.999)}
    opt_gen = Adam(
        chain(gen_s_t.parameters(), gen_t_s.parameters(),
              cls_s.parameters(), cls_t.parameters()), **opt_config)
    opt_dis = Adam(chain(dis_s.parameters(), dis_t.parameters()), **opt_config)

    calc_ls = GANLoss(device, use_lsgan=True)
    calc_ce = F.cross_entropy

    fake_src_x_pool = ImagePool(pool_size * batch_size)
    fake_tgt_x_pool = ImagePool(pool_size * batch_size)

    src_train_iter = iter(
        DataLoader(src_train, batch_size=batch_size, num_workers=4,
                   sampler=InfiniteSampler(len(src_train))))
    tgt_train_iter = iter(
        DataLoader(tgt_train, batch_size=batch_size, num_workers=4,
                   sampler=InfiniteSampler(len(tgt_train))))
    tgt_test_loader = DataLoader(tgt_test, batch_size=batch_size * 4, num_workers=4)
    print('Training...')

    cls_s.train()
    cls_t.train()

    niter = 0
    while True:
        niter += 1
        src_x, src_y = next(src_train_iter)
        tgt_x = next(tgt_train_iter)
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x = tgt_x.to(device)

        # enable the pseudo-label term for the last quarter of training
        if niter >= num_epochs * 0.75 * iter_per_epoch:
            eta = float(config['weight']['eta'])

        fake_tgt_x = gen_s_t(src_x)
        fake_back_src_x = gen_t_s(fake_tgt_x)
        fake_src_x = gen_t_s(tgt_x)

        with torch.no_grad():
            fake_src_pseudo_y = torch.max(cls_s(fake_src_x), dim=1)[1]

        # eq2
        loss_gen = beta * calc_ce(cls_t(fake_tgt_x), src_y)
        loss_gen += mu * calc_ce(cls_s(src_x), src_y)
        # eq3
        loss_gen += gamma * calc_ls(dis_s(fake_src_x), True)
        loss_gen += alpha * calc_ls(dis_t(fake_tgt_x), True)
        # eq5
        loss_gen += eta * calc_ce(cls_s(fake_src_x), fake_src_pseudo_y)
        # eq6
        loss_gen += new * calc_ce(cls_s(fake_back_src_x), src_y)

        # do not backpropagate discriminator loss to the generators
        fake_tgt_x = fake_tgt_x.detach()
        fake_src_x = fake_src_x.detach()
        fake_back_src_x = fake_back_src_x.detach()

        # eq3
        loss_dis_s = gamma * calc_ls(dis_s(fake_src_x_pool.query(fake_src_x)), False)
        loss_dis_s += gamma * calc_ls(dis_s(src_x), True)
        loss_dis_t = alpha * calc_ls(dis_t(fake_tgt_x_pool.query(fake_tgt_x)), False)
        loss_dis_t += alpha * calc_ls(dis_t(tgt_x), True)
        loss_dis = loss_dis_s + loss_dis_t

        for opt, loss in zip([opt_dis, opt_gen], [loss_dis, loss_gen]):
            opt.zero_grad()
            loss.backward(retain_graph=True)
            opt.step()

        if niter % 100 == 0 and niter > 0:
            writer.add_scalar('dis/src', loss_dis_s.item(), niter)
            writer.add_scalar('dis/tgt', loss_dis_t.item(), niter)
            writer.add_scalar('gen', loss_gen.item(), niter)

        if niter % iter_per_epoch == 0:
            epoch = niter // iter_per_epoch

            if epoch % 10 == 0:
                data = []
                for x in [src_x, fake_tgt_x, fake_back_src_x, tgt_x, fake_src_x]:
                    x = x.to(torch.device('cpu'))
                    if x.size(1) == 1:
                        x = x.repeat(1, 3, 1, 1)  # grayscale2rgb
                    data.append(x)
                grid = make_grid(torch.cat(tuple(data), dim=0),
                                 normalize=True, range=(-1.0, 1.0))
                writer.add_image('generated', grid, epoch)

            cls_t.eval()
            n_err = 0
            with torch.no_grad():
                for tgt_x, tgt_y in tgt_test_loader:
                    prob_y = F.softmax(cls_t(tgt_x.to(device)), dim=1)
                    pred_y = torch.max(prob_y, dim=1)[1]
                    pred_y = pred_y.to(torch.device('cpu'))
                    n_err += (pred_y != tgt_y).sum().item()
            writer.add_scalar('err_tgt', n_err / len(tgt_test), epoch)
            cls_t.train()

            if epoch % 50 == 0:
                models_dict = {
                    'cls_s': cls_s, 'cls_t': cls_t,
                    'dis_s': dis_s, 'dis_t': dis_t,
                    'gen_s_t': gen_s_t, 'gen_t_s': gen_t_s
                }
                filename = '{:s}/epoch{:d}.tar'.format(log_dir, epoch)
                save_models_dict(models_dict, filename)

            if epoch >= num_epochs:
                break
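# ----------------------------------------------------------------------
# Note: `ImagePool` (the fake-image history buffer queried before each
# discriminator update above) follows the CycleGAN trick of showing the
# discriminator a mix of current and past generator outputs. A minimal
# sketch, assuming the standard 50%-replacement policy; the replacement
# probability and buffer details are assumptions, not this repo's exact code.
import random
import torch

class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images  # history disabled
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)  # fill the buffer first
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())  # hand back an older fake
                self.images[idx] = img                # stash the current one
            else:
                out.append(img)
        return torch.cat(out, dim=0)
# ----------------------------------------------------------------------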
def experiment(exp, num_epochs, pretrain, consistency):
    config = get_config('config.yaml')
    identifier = '{:s}_ndf{:d}_ngf{:d}'.format(consistency, config['dis']['ndf'],
                                               config['gen']['ngf'])
    log_dir = 'log/{:s}/{:s}'.format(exp, identifier)
    snapshot_dir = 'snapshot/{:s}/{:s}'.format(exp, identifier)
    writer = SummaryWriter(log_dir=log_dir)
    os.makedirs(snapshot_dir, exist_ok=True)
    shutil.copy('config.yaml', '{:s}/{:s}'.format(snapshot_dir, 'config.yaml'))

    batch_size = int(config['batch_size'])
    pool_size = int(config['pool_size'])
    lr = float(config['lr'])
    weight_decay = float(config['weight_decay'])
    device = torch.device('cuda')

    src, tgt = load_source_target_datasets(exp)
    n_ch_s = src.train_X.shape[1]  # number of color channels
    n_ch_t = tgt.train_X.shape[1]  # number of color channels
    n_class = src.n_classes

    train_tfs = get_composed_transforms()
    test_tfs = get_composed_transforms()
    src_train = DADataset(src.train_X, src.train_y, train_tfs)
    src_test = DADataset(src.test_X, src.test_y, train_tfs)
    tgt_train = DADataset(tgt.train_X, tgt.train_y, train_tfs)
    tgt_train = SubsetDataset(tgt_train, range(1000))  # fix indices
    tgt_test = DADataset(tgt.test_X, tgt.test_y, test_tfs)
    del src, tgt

    n_sample = max(len(src_train), len(tgt_train))
    iter_per_epoch = n_sample // batch_size + 1

    cls_s = Classifier(n_class, n_ch_s).to(device)
    cls_t = Classifier(n_class, n_ch_t).to(device)
    if not pretrain:
        load_model(cls_s, 'snapshot/{:s}/pretrain_cls_s.tar'.format(exp))

    gen_s_t_params = {'input_nc': n_ch_s, 'output_nc': n_ch_t}
    gen_t_s_params = {'input_nc': n_ch_t, 'output_nc': n_ch_s}
    gen_s_t = define_G(**{**config['gen'], **gen_s_t_params}).to(device)
    gen_t_s = define_G(**{**config['gen'], **gen_t_s_params}).to(device)

    dis_s = define_D(**{**config['dis'], 'input_nc': n_ch_s}).to(device)
    dis_t = define_D(**{**config['dis'], 'input_nc': n_ch_t}).to(device)

    opt_config = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0.5, 0.99)}
    opt_gen = Adam(chain(gen_s_t.parameters(), gen_t_s.parameters(),
                         cls_s.parameters(), cls_t.parameters()), **opt_config)
    opt_dis = Adam(chain(dis_s.parameters(), dis_t.parameters()), **opt_config)

    calc_ls = GANLoss(device, use_lsgan=True).to(device)
    calc_ce = torch.nn.CrossEntropyLoss().to(device)
    calc_l1 = torch.nn.L1Loss().to(device)

    fake_src_x_pool = ImagePool(pool_size * batch_size)
    fake_tgt_x_pool = ImagePool(pool_size * batch_size)

    src_train_iter = iter(
        DataLoader(src_train, batch_size=batch_size, num_workers=4,
                   sampler=InfiniteSampler(len(src_train))))
    tgt_train_iter = iter(
        DataLoader(tgt_train, batch_size=batch_size, num_workers=4,
                   sampler=InfiniteSampler(len(tgt_train))))
    src_test_loader = DataLoader(src_test, batch_size=batch_size * 4, num_workers=4)
    tgt_test_loader = DataLoader(tgt_test, batch_size=batch_size * 4, num_workers=4)
    print('Training...')

    cls_s.train()
    cls_t.train()

    niter = 0
    if pretrain:
        while True:
            niter += 1
            src_x, src_y = next(src_train_iter)
            loss = calc_ce(cls_s(src_x.to(device)), src_y.to(device))
            opt_gen.zero_grad()
            loss.backward()
            opt_gen.step()

            if niter % iter_per_epoch == 0:
                epoch = niter // iter_per_epoch
                n_err = evaluate_classifier(cls_s, tgt_test_loader, device)
                print(epoch, n_err / len(tgt_test))
                # n_err = evaluate_classifier(cls_s, src_test_loader, device)
                # print(epoch, n_err / len(src_test))
                if epoch >= num_epochs:
                    save_model(cls_s, '{:s}/pretrain_cls_s.tar'.format(snapshot_dir))
                    break
        exit()

    while True:
        niter += 1
        src_x, src_y = next(src_train_iter)
        tgt_x, tgt_y = next(tgt_train_iter)
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x, tgt_y = tgt_x.to(device), tgt_y.to(device)

        fake_tgt_x = gen_s_t(src_x)
        fake_back_src_x = gen_t_s(fake_tgt_x)
        fake_src_x = gen_t_s(tgt_x)
        fake_back_tgt_x = gen_s_t(fake_src_x)

        #################
        # discriminator #
        #################
        loss_dis_s = calc_ls(dis_s(fake_src_x_pool.query(fake_src_x.detach())), False)
        loss_dis_s += calc_ls(dis_s(src_x), True)
        loss_dis_t = calc_ls(dis_t(fake_tgt_x_pool.query(fake_tgt_x.detach())), False)
        loss_dis_t += calc_ls(dis_t(tgt_x), True)
        loss_dis = loss_dis_s + loss_dis_t

        ##########################
        # generator + classifier #
        ##########################
        # classification
        loss_gen_cls_s = calc_ce(cls_s(src_x), src_y)
        loss_gen_cls_t = calc_ce(cls_t(tgt_x), tgt_y)
        loss_gen_cls = loss_gen_cls_s + loss_gen_cls_t

        # augmented cycle consistency
        if consistency == 'augmented':
            loss_gen_aug_s = calc_ce(cls_s(fake_src_x), tgt_y)
            loss_gen_aug_s += calc_ce(cls_s(fake_back_src_x), src_y)
            loss_gen_aug_t = calc_ce(cls_t(fake_tgt_x), src_y)
            loss_gen_aug_t += calc_ce(cls_t(fake_back_tgt_x), tgt_y)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        elif consistency == 'relaxed':
            loss_gen_aug_s = calc_ce(cls_s(fake_back_src_x), src_y)
            loss_gen_aug_t = calc_ce(cls_t(fake_back_tgt_x), tgt_y)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        elif consistency == 'simple':
            loss_gen_aug_s = calc_ce(cls_s(fake_src_x), tgt_y)
            loss_gen_aug_t = calc_ce(cls_t(fake_tgt_x), src_y)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        elif consistency == 'cycle':
            loss_gen_aug_s = calc_l1(fake_back_src_x, src_x)
            loss_gen_aug_t = calc_l1(fake_back_tgt_x, tgt_x)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        else:
            raise NotImplementedError

        # deceive discriminator
        loss_gen_adv_s = calc_ls(dis_s(fake_src_x), True)
        loss_gen_adv_t = calc_ls(dis_t(fake_tgt_x), True)
        loss_gen_adv = loss_gen_adv_s + loss_gen_adv_t

        loss_gen = loss_gen_cls + loss_gen_aug + loss_gen_adv

        opt_dis.zero_grad()
        loss_dis.backward()
        opt_dis.step()

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if niter % 100 == 0 and niter > 0:
            writer.add_scalar('dis/src', loss_dis_s.item(), niter)
            writer.add_scalar('dis/tgt', loss_dis_t.item(), niter)
            writer.add_scalar('gen/cls_s', loss_gen_cls_s.item(), niter)
            writer.add_scalar('gen/cls_t', loss_gen_cls_t.item(), niter)
            writer.add_scalar('gen/aug_s', loss_gen_aug_s.item(), niter)
            writer.add_scalar('gen/aug_t', loss_gen_aug_t.item(), niter)
            writer.add_scalar('gen/adv_s', loss_gen_adv_s.item(), niter)
            writer.add_scalar('gen/adv_t', loss_gen_adv_t.item(), niter)

        if niter % iter_per_epoch == 0:
            epoch = niter // iter_per_epoch

            if epoch % 1 == 0:
                data = []
                for x in [src_x, fake_tgt_x, fake_back_src_x,
                          tgt_x, fake_src_x, fake_back_tgt_x]:
                    x = x.to(torch.device('cpu'))
                    if x.size(1) == 1:
                        x = x.repeat(1, 3, 1, 1)  # grayscale2rgb
                    data.append(x)
                grid = make_grid(torch.cat(tuple(data), dim=0), nrow=16,
                                 normalize=True, range=(-1.0, 1.0))
                writer.add_image('generated', grid, epoch)

            n_err = evaluate_classifier(cls_t, tgt_test_loader, device)
            writer.add_scalar('err_tgt', n_err / len(tgt_test), epoch)

            if epoch % 50 == 0:
                models_dict = {
                    'cls_s': cls_s, 'cls_t': cls_t,
                    'dis_s': dis_s, 'dis_t': dis_t,
                    'gen_s_t': gen_s_t, 'gen_t_s': gen_t_s
                }
                filename = '{:s}/epoch{:d}.tar'.format(snapshot_dir, epoch)
                save_models_dict(models_dict, filename)

            if epoch >= num_epochs:
                break
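# ----------------------------------------------------------------------
# Note: `GANLoss(device, use_lsgan=True)` is called above as
# `calc_ls(prediction, target_is_real)`. A minimal sketch, assuming LSGAN
# means MSE against constant 1/0 targets and the vanilla path uses
# BCE-with-logits; subclassing nn.Module is an assumption motivated by the
# `.to(device)` call in the script above. The repo's real class may cache
# target tensors or handle multi-scale discriminator outputs differently.
import torch
import torch.nn as nn

class GANLoss(nn.Module):
    def __init__(self, device, use_lsgan=True):
        super().__init__()
        self.device = device
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        value = 1.0 if target_is_real else 0.0
        # build a constant target with the same shape as the prediction
        target = torch.full_like(prediction, value).to(self.device)
        return self.loss(prediction, target)
# ----------------------------------------------------------------------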
def build_models(self):
    ################### encoders #########################################
    self.E_image = None
    self.E_text = None
    use_con = False
    use_sigmoid = (self.opt.c_gan_mode == 'dcgan')
    if self.opt.c_type == 'image':
        if self.opt.model == 'supervised':
            self.E_image = E_ResNet_Local(input_nc=self.opt.output_nc, output_nc=self.opt.nc,
                                          nef=self.opt.nef, n_blocks=self.opt.e_blocks,
                                          norm_type=self.opt.norm)
        elif self.opt.model == 'unsupervised':
            self.E_image = CE_ResNet_Local(input_nc=self.opt.output_nc, output_nc=self.opt.nc,
                                           nef=self.opt.nef, n_blocks=self.opt.e_blocks,
                                           c_dim=self.opt.ne, norm_type=self.opt.norm)
    elif self.opt.c_type == 'text':
        use_con = True
        self.E_text = RNN_ENCODER(self.opt.n_words, nhidden=self.opt.nc)
        state_dict = torch.load(self.opt.E_text_path,
                                map_location=lambda storage, loc: storage)
        self.E_text.load_state_dict(state_dict)
        for p in self.E_text.parameters():
            p.requires_grad = False
        print('Loaded text encoder successfully')
        self.E_text.eval()
    elif self.opt.c_type == 'image_text':
        use_con = True
        self.E_image = E_ResNet_Global(input_nc=self.opt.output_nc, output_nc=self.opt.ne,
                                       nef=self.opt.nef, n_blocks=self.opt.e_blocks,
                                       norm_type=self.opt.norm)
        self.E_text = RNN_ENCODER(self.opt.n_words, nhidden=self.opt.nc)
        state_dict = torch.load(self.opt.E_text_path,
                                map_location=lambda storage, loc: storage)
        self.E_text.load_state_dict(state_dict)
        for p in self.E_text.parameters():
            p.requires_grad = False
        print('Loaded text encoder successfully')
        self.E_text.eval()
    elif self.opt.c_type == 'label':
        use_con = True
    elif self.opt.c_type == 'image_label':
        use_con = True
        self.E_image = E_ResNet_Global(input_nc=self.opt.output_nc, output_nc=self.opt.ne,
                                       nef=self.opt.nef, n_blocks=self.opt.e_blocks,
                                       norm_type=self.opt.norm)
    else:
        raise NotImplementedError('Unsupported conditional type: {}'.format(self.opt.c_type))

    ################### generator #########################################
    self.G = RAG_NET(input_nc=self.opt.input_nc, ngf=self.opt.ngf,
                     nc=self.opt.nc, ne=self.opt.ne, norm_type=self.opt.norm)

    if self.opt.isTrain:
        ################### discriminators #####################################
        self.Ds = []
        self.Ds2 = None
        bnf = 3 if self.opt.fineSize <= 128 else 4
        self.Ds.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=bnf,
                             nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,
                             norm_type=self.opt.norm))
        self.Ds.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4,
                             nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,
                             norm_type=self.opt.norm))
        self.Ds.append(D_NET_Multi(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4,
                                   nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,
                                   norm_type=self.opt.norm))
        if self.opt.model == 'unsupervised' and self.opt.c_type == 'image':
            self.Ds2 = []
            self.Ds2.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=bnf,
                                  nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,
                                  norm_type=self.opt.norm))
            self.Ds2.append(D_NET(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4,
                                  nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,
                                  norm_type=self.opt.norm))
            self.Ds2.append(D_NET_Multi(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=4,
                                        nc=self.opt.nc, use_con=use_con, use_sigmoid=use_sigmoid,
                                        norm_type=self.opt.norm))
        ################### init_weights ########################################
        self.G.apply(weights_init(self.opt.init_type))
        for i in range(self.D_len):
            self.Ds[i].apply(weights_init(self.opt.init_type))
        if self.Ds2 is not None:
            for i in range(self.D_len):
                self.Ds2[i].apply(weights_init(self.opt.init_type))
        if self.E_image is not None:
            self.E_image.apply(weights_init(self.opt.init_type))
        ################### use GPU #############################################
        self.G.cuda()
        for i in range(self.D_len):
            self.Ds[i].cuda()
        if self.Ds2 is not None:
            for i in range(self.D_len):
                self.Ds2[i].cuda()
        if self.E_image is not None:
            self.E_image.cuda()
        if self.E_text is not None:
            self.E_text.cuda()
        ################### set criterion ########################################
        self.criterionGAN = GANLoss(mse_loss=True)
        if use_con:
            self.criterionCGAN = GANLoss(mse_loss=not use_sigmoid)
        self.criterionKL = KL_loss
        ################## define optimizers #####################################
        self.define_optimizers()
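# ----------------------------------------------------------------------
# Note: `KL_loss` assigned to `self.criterionKL` above is not defined in
# this snippet; a minimal sketch, assuming the standard VAE KL divergence
# between the encoder posterior N(mu, sigma^2) and the prior N(0, I), with
# `logvar` the predicted log-variance. The real function may reduce by mean
# instead of sum.
import torch

def KL_loss(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# ----------------------------------------------------------------------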