def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    print(generator)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        discriminator.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_discriminator.zero_grad()
            d_pos_loss, d_neg_loss, gp = disc_loss_generation(data, args.z_dim, discriminator, generator, args.device)
            d_pos_loss.backward()
            d_neg_loss.backward(mone)
            (10 * gp).backward()
            optim_discriminator.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(args.train_batch_size, args.z_dim, discriminator, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            noise = torch.randn(data.shape[0], args.z_dim, device=args.device)
            evaluate(args.visualiser, noise, data, generator, i)
            d_loss = (d_pos_loss - d_neg_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Discriminator loss')
            args.visualiser.plot(step=i, data=t_loss.detach().cpu().numpy(), title='Generator loss')
            args.visualiser.plot(step=i, data=gp.detach().cpu().numpy(), title='Gradient Penalty')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
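# Every training loop in this repo draws batches through a `sample(iterator, loader)`
# helper whose implementation is not shown. A minimal sketch of the pattern these
# call sites imply (an assumption, not the repo's definition): return the next
# batch and rebuild the iterator whenever the DataLoader is exhausted.
def sample(iterator, loader):
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)   # epoch boundary: restart the iterator
        batch = next(iterator)
    return batch, iterator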
def train(args):
    parameters = vars(args)
    train_loader, valid_loader, test_loader = args.loader
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    classifier = models['classifier'].to(args.device)
    print(classifier)

    optim_classifier = optim.SGD(classifier.parameters(), lr=args.lr, momentum=args.momentum,
                                 nesterov=args.nesterov, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim_classifier, args.iterations)

    it = iter(train_loader)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    best_accuracy = 0
    for i in range(iteration, args.iterations):
        classifier.train()
        batch, it = sample(it, train_loader)
        data = batch[0].to(args.device)
        classes = batch[1].to(args.device)
        optim_classifier.zero_grad()
        loss = classification_loss(data, classes, classifier)
        loss.backward()
        optim_classifier.step()
        scheduler.step()

        if i % args.evaluate == 0:
            classifier.eval()
            print('Iter: %s' % i, time.time() - t0)
            valid_accuracy = evaluate(args.visualiser, i, valid_loader, classifier, 'valid', args.device)
            test_accuracy = evaluate(args.visualiser, i, test_loader, classifier, 'test', args.device)
            loss = loss.detach().cpu().numpy()
            args.visualiser.plot(loss, title='loss', step=i)
            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                save_models(models, i, args.model_path, args.checkpoint)
            t0 = time.time()
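# `classification_loss` is not defined in this file. A plausible stand-in,
# assuming it is plain cross-entropy on the classifier's logits (the actual
# loss used by the repo may differ):
import torch.nn.functional as F

def classification_loss(data, classes, classifier):
    logits = classifier(data)
    return F.cross_entropy(logits, classes)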
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    train_loader3, test_loader3 = args.loaders3
    #train_loader4, test_loader4 = args.loaders4
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    criticz1 = models['criticz1'].to(args.device)
    criticz2 = models['criticz2'].to(args.device)
    #criticw1 = models['criticw1'].to(args.device)
    #criticw2 = models['criticw2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)
    print(criticz1)
    #print(criticw1)

    optim_criticx1 = optim.Adam(criticx1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticz1 = optim.Adam(criticz1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticz2 = optim.Adam(criticz2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    #optim_criticw1 = optim.Adam(criticw1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    #optim_criticw2 = optim.Adam(criticw2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2, iter3 = iter(train_loader1), iter(train_loader2), iter(train_loader3)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2, titer3 = iter(test_loader1), iter(test_loader2), iter(test_loader3)
    mone = torch.FloatTensor([-1]).to(args.device)  # backward(mone) flips the sign: the critics are maximized
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        criticz1.train()
        criticz2.train()
        #criticw1.train()
        #criticw2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx.to(args.device)
            batchx, iter1 = sample(iter1, train_loader1)
            input_data = batchx.to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy.to(args.device)
            batchz, iter3 = sample(iter3, train_loader3)
            dataz = batchz.to(args.device)
            #batchw, iter4 = sample(iter4, train_loader4)
            #dataw = batchw[0].to(args.device)

            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, data, args.eps, args.lp, criticx1, criticx2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, datay, args.eps, args.lp, criticy1, criticy2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

            optim_criticz1.zero_grad()
            optim_criticz2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(input_data, dataz, args.eps, args.lp, criticz1, criticz2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticz1.step()
            optim_criticz2.step()

            #optim_criticw1.zero_grad()
            #optim_criticw2.zero_grad()
            #r_loss, g_loss, p = disc_loss_generation(input_data, dataw, args.eps, args.lp, criticw1, criticw2)
            #(r_loss + g_loss + p).backward(mone)
            #optim_criticw1.step()
            #optim_criticw2.step()

        optim_generator.zero_grad()
        # Sample barycentric weights on the simplex and tile each batch nt times
        t_ = Dirichlet(torch.FloatTensor([1., 1., 1.])).sample().to(args.device)
        t = torch.stack([t_] * input_data.shape[0])
        tinputdata = torch.cat([input_data] * args.nt)
        tdata = torch.cat([data] * args.nt)
        tdatay = torch.cat([datay] * args.nt)
        tdataz = torch.cat([dataz] * args.nt)
        #tdataw = torch.cat([dataw]*args.nt)
        t_lossx = transfer_loss(tinputdata, tdata, args.nt, t, args.eps, args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(tinputdata, tdatay, args.nt, t, args.eps, args.lp, criticy1, criticy2, generator)
        t_lossz = transfer_loss(tinputdata, tdataz, args.nt, t, args.eps, args.lp, criticz1, criticz2, generator)
        #t_lossw = transfer_loss(tinputdata, tdataw, args.nt, t, args.eps, args.lp, criticw1, criticw2, generator)
        t_loss = (t_[0] * t_lossx + t_[1] * t_lossy + t_[2] * t_lossz).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            #labelx = batchx[1].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            #labely = batchy[1].to(args.device)
            batchz, titer3 = sample(titer3, test_loader3)
            dataz = batchz.to(args.device)
            #labelz = batchz[1].to(args.device)
            evaluate(args.visualiser, datax, datay, dataz, generator, 'x', args.device)
            #batchw, titer4 = sample(titer4, test_loader4)
            #dataw = batchw[0].to(args.device)
            #labelw = batchw[1].to(args.device)
            #evaluate_3d(args.visualiser, datax, datay, dataz, dataw, labelx, labely, labelz, labelw, generator, 'x', args.device)
            d_loss = (r_loss + g_loss).detach().cpu().numpy()  # from the last critic update (z)
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Z')
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss Y')
            args.visualiser.plot(step=i, data=t_lossz.detach().cpu().numpy(), title='Generator loss Z')
            #args.visualiser.plot(step=i, data=t_lossw.detach().cpu().numpy(), title='Generator loss W')
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
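# The generator step above mixes the three per-domain transfer losses with
# weights t_ drawn from a flat Dirichlet, so the weights are non-negative and
# sum to one. A minimal illustration of that convex combination (the loss
# values here are stand-ins, not repo outputs):
from torch.distributions.dirichlet import Dirichlet
import torch

t_ = Dirichlet(torch.tensor([1.0, 1.0, 1.0])).sample()   # e.g. tensor([0.2, 0.5, 0.3]), sums to 1
losses = torch.tensor([1.0, 2.0, 3.0])                    # stand-in per-domain losses
mixed = (t_ * losses).sum()                               # barycentric mixture of the three losses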
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    critic1 = models['critic1'].to(args.device)
    critic2 = models['critic2'].to(args.device)
    print(generator)
    print(critic1)
    print(critic2)

    optim_critic1 = optim.Adam(critic1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_critic2 = optim.Adam(critic2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic1.train()
        critic2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            if data.shape[0] != args.train_batch_size:  # resample if the final batch is short
                batchx, iter1 = sample(iter1, train_loader1)
                data = batchx[0].to(args.device)
            optim_critic1.zero_grad()
            optim_critic2.zero_grad()
            z = torch.randn(data.shape[0], args.z_dim, device=args.device)
            gen1 = generator(z).detach()
            r_loss, g_loss, p = disc_loss_generation(data, gen1, args.eps, args.lp, critic1, critic2)
            (r_loss + g_loss + p).backward(mone)
            optim_critic1.step()
            optim_critic2.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(data, args.eps, args.lp, args.z_dim, critic1, critic2, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            noise = torch.randn(args.test_batch_size, args.z_dim, device=args.device)
            evaluate(args.visualiser, noise, data[:args.test_batch_size], generator, i)
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss')
            args.visualiser.plot(step=i, data=t_loss.detach().cpu().numpy(), title='Generator loss')
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
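# `infer_iteration` (used by every loop to resume training) is not shown.
# A hypothetical sketch, assuming checkpoints are saved under names like
# '<model_name>:<iteration>' inside model_path; the real helper and naming
# scheme may differ (save_path is unused in this sketch).
import os
import re

def infer_iteration(model_name, reload, model_path, save_path):
    if not reload:
        return 0
    pattern = re.compile(rf'{re.escape(model_name)}:(\d+)')
    steps = [int(m.group(1))
             for f in os.listdir(model_path)
             for m in [pattern.match(f)] if m]
    return max(steps, default=0)   # resume from the latest saved step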
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    critic1 = models['critic1'].to(args.device)
    critic2 = models['critic2'].to(args.device)
    print(generator)
    print(critic1)

    optim_critic1 = optim.Adam(critic1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_critic2 = optim.Adam(critic2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic1.train()
        critic2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx.to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy.to(args.device)
            optim_critic1.zero_grad()
            r_loss1 = critic_loss(data1, data1, args.alpha, critic1, generator, args.device)
            r_loss1.backward(mone)
            optim_critic1.step()
            optim_critic2.zero_grad()
            r_loss2 = critic_loss(data1, data2, args.alpha, critic2, generator, args.device)
            r_loss2.backward(mone)
            optim_critic2.step()

        optim_generator.zero_grad()
        for _ in range(10):
            # `sample_n` is deprecated; `.sample((1,))` draws the same shape-(1,) tensor
            t_ = torch.distributions.beta.Beta(args.alpha, args.alpha).sample((1,)).to(args.device)
            t = torch.stack([t_] * data1.shape[0])
            t_loss1 = transfer_loss(data1, data1, t, critic1, generator, args.device)**2
            t_loss2 = transfer_loss(data1, data2, t, critic2, generator, args.device)**2
            ((1 - t_) * t_loss1 + t_ * t_loss2).backward()
            #t_ = torch.FloatTensor([0]).to(args.device)
            #t = torch.stack([t_]*data1.shape[0])
            #gen = generator(data1, t)
            #reg = F.mse_loss(gen, data1)
            #(10*reg).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy.to(args.device)
            evaluate(args.visualiser, data1, data2, args.z_dim, generator, i, args.device)
            args.visualiser.plot(step=i, data=r_loss1.detach().cpu().numpy(), title='Critic loss 1')
            args.visualiser.plot(step=i, data=r_loss2.detach().cpu().numpy(), title='Critic loss 2')
            args.visualiser.plot(step=i, data=t_loss1.detach().cpu().numpy(), title='Generator loss 1')
            args.visualiser.plot(step=i, data=t_loss2.detach().cpu().numpy(), title='Generator loss 2')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
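# The generator step above draws a scalar interpolation time t_ ~ Beta(alpha, alpha)
# and tiles it over the batch, so every sample shares the same time. A minimal
# shape illustration (alpha and the batch size of 8 are stand-ins):
from torch.distributions.beta import Beta
import torch

alpha = 0.5
t_ = Beta(alpha, alpha).sample((1,))   # shape (1,), value in [0, 1]
t = torch.stack([t_] * 8)              # tiled over a batch of 8 -> shape (8, 1)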
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    g12 = models['g12'].to(args.device)
    g21 = models['g21'].to(args.device)
    d1 = models['d1'].to(args.device)
    d2 = models['d2'].to(args.device)
    eval_model = args.evaluation.eval().to(args.device)
    print(g12)
    print(g21)
    print(d1)
    print(d2)

    optim_g12 = optim.Adam(g12.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_g21 = optim.Adam(g21.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_d1 = optim.Adam(d1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_d2 = optim.Adam(d2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    titer1 = iter(test_loader1)
    titer2 = iter(test_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        g12.train()
        g21.train()
        d1.train()
        d2.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        dloss1 = disc_loss(data2, data1, g21, d1)
        optim_d1.zero_grad()
        dloss1.backward()
        optim_d1.step()

        dloss2 = disc_loss(data1, data2, g12, d2)
        optim_d2.zero_grad()
        dloss2.backward()
        optim_d2.step()

        gloss1 = generator_loss(data1, g12, g21, d2)
        optim_g12.zero_grad()
        optim_g21.zero_grad()
        gloss1.backward()
        optim_g12.step()
        optim_g21.step()

        gloss2 = generator_loss(data2, g21, g12, d1)
        optim_g12.zero_grad()
        optim_g21.zero_grad()
        gloss2.backward()
        optim_g12.step()
        optim_g21.step()

        if i % args.evaluate == 0:
            g12.eval()
            g21.eval()
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy[0].to(args.device)
            print('Iter: %s' % i, time.time() - t0)
            evaluate(args.visualiser, data1, g12, 'x')
            evaluate(args.visualiser, data2, g21, 'y')
            evaluate_gen_class_accuracy(args.visualiser, i, test_loader1, eval_model, g12, args.device)
            evaluate_class_accuracy(args.visualiser, i, test_loader2, eval_model, args.device)
            args.visualiser.plot(dloss1.cpu().detach().numpy(), title='Discriminator loss1', step=i)
            args.visualiser.plot(dloss2.cpu().detach().numpy(), title='Discriminator loss2', step=i)
            args.visualiser.plot(gloss1.cpu().detach().numpy(), title='Generator loss 1-2', step=i)
            args.visualiser.plot(gloss2.cpu().detach().numpy(), title='Generator loss 2-1', step=i)
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
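# `generator_loss(x, g_ab, g_ba, d_b)` above is not shown. A common
# CycleGAN-style form is sketched here as an assumption only: an adversarial
# term against the target-domain discriminator plus a cycle-consistency term.
# The repo's actual loss (and its adversarial variant) may differ.
import torch
import torch.nn.functional as F

def generator_loss(x, g_ab, g_ba, d_b, cycle_weight=10.0):
    fake_b = g_ab(x)                           # translate a -> b
    logits = d_b(fake_b)
    adv = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
    cycle = F.l1_loss(g_ba(fake_b), x)         # a -> b -> a should reconstruct x
    return adv + cycle_weight * cycle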
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    encoder = models['classifier'].to(args.device)
    contrastive = models['contrastive'].to(args.device)
    print(encoder)
    print(contrastive)

    optim_encoder = optim.Adam(encoder.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_contrastive = optim.Adam(contrastive.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        encoder.train()
        contrastive.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].float().to(args.device)
            optim_contrastive.zero_grad()
            ploss, nloss, gp = contrastive_loss(datax, args.nc, encoder, contrastive, args.device)
            (ploss - nloss + gp).backward()
            optim_contrastive.step()

        optim_encoder.zero_grad()
        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].float().to(args.device)
        dataxp = batchx[1].float().to(args.device)
        dloss, closs = compute_loss(datax, dataxp, encoder, contrastive, args.device)
        (args.ld * dloss + closs).backward()
        optim_encoder.step()

        if i % args.evaluate == 0:
            encoder.eval()
            contrastive.eval()
            print('Iter: {}'.format(i), end=': ')
            evaluate(args.visualiser, datax, dataxp, 'x')
            _acc = evaluate_accuracy(args.visualiser, i, test_loader1, encoder, args.nc, 'x', args.device)
            print('disc loss: {}'.format((ploss - nloss).detach().cpu().numpy()), end='\t')
            print('gp: {}'.format(gp.detach().cpu().numpy()), end='\t')
            print('positive dist loss: {}'.format(dloss.detach().cpu().numpy()), end='\t')
            print('contrast. loss: {}'.format(closs.detach().cpu().numpy()), end='\t')
            print('Accuracy: {}'.format(_acc))
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    ssx = args.ssx.to(args.device)
    ssx.eval()
    # Embed both domains with the frozen self-supervised encoder, then spectrally
    # cluster the source embeddings to obtain pseudo-labels
    zxs, labelsx = get_initial_zx(train_loader1, ssx, args.device)
    zys, labelsy = get_initial_zx(train_loader2, ssx, args.device)
    sc = SpectralClustering(args.nc, affinity='sigmoid', gamma=1.7)
    clusters = sc.fit_predict(zxs.cpu().numpy())
    clusters = torch.from_numpy(clusters).to(args.device)

    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    classifier.apply(he_init)
    discriminator.apply(he_init)
    print(classifier)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_classifier = optim.Adam(classifier.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optims = {'optim_discriminator': optim_discriminator, 'optim_classifier': optim_classifier}

    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()
        perm = torch.randperm(len(zxs))
        ix = perm[:args.train_batch_size]
        zx = zxs[ix]
        perm = torch.randperm(len(zys))
        iy = perm[:args.train_batch_size]
        zy = zys[iy]
        optim_discriminator.zero_grad()
        d_loss = disc_loss(zx, zy, discriminator, classifier.x, classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        perm = torch.randperm(len(zxs))
        ix = perm[:args.train_batch_size]
        zx = zxs[ix]
        label = clusters[ix].long()
        perm = torch.randperm(len(zys))
        iy = perm[:args.train_batch_size]
        zy = zys[iy]
        optim_classifier.zero_grad()
        c_loss = classification_loss(zx, label, classifier)
        tcw_loss = classification_target_loss(zy, classifier)
        dw_loss = embed_div_loss(zx, zy, discriminator, classifier.x, classifier.mlp, args.device)
        m_loss1 = mixup_loss(zx, classifier, args.device)
        m_loss2 = mixup_loss(zy, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            classifier.eval()
            class_map = evaluate_cluster(args.visualiser, i, args.nc, zxs, labelsx, classifier, 'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, zxs, labelsx, class_map, classifier, 'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, zys, labelsy, class_map, classifier, 'y', args.device)
            save_path = args.save_path
            with open(os.path.join(save_path, 'c_loss'), 'a') as f:
                f.write(f'{i},{c_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'tcw_loss'), 'a') as f:
                f.write(f'{i},{tcw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'dw_loss'), 'a') as f:
                f.write(f'{i},{dw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss1'), 'a') as f:
                f.write(f'{i},{m_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss2'), 'a') as f:
                f.write(f'{i},{m_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'd_loss2'), 'a') as f:
                f.write(f'{i},{d_loss.cpu().item()}\n')
            args.visualiser.plot(c_loss.cpu().detach().numpy(), title='Source classifier loss', step=i)
            args.visualiser.plot(tcw_loss.cpu().detach().numpy(), title='Target classifier cross entropy', step=i)
            args.visualiser.plot(dw_loss.cpu().detach().numpy(), title='Classifier marginal divergence', step=i)
            args.visualiser.plot(m_loss1.cpu().detach().numpy(), title='Source mix up loss', step=i)
            args.visualiser.plot(m_loss2.cpu().detach().numpy(), title='Target mix up loss', step=i)
            args.visualiser.plot(d_loss.cpu().detach().numpy(), title='Discriminator loss', step=i)
            t0 = time.time()
        save_models(models, i, args.model_path, args.evaluate)
        save_models(optims, i, args.model_path, args.evaluate)
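# `mixup_loss` is used by several loops here but not defined in this file. A
# common consistency-style form is sketched below as an assumption: mix the
# batch with a permutation of itself and penalize the gap between the
# prediction on the mixture and the mixed predictions. The repo's version may
# instead use mixed label targets.
import torch
import torch.nn.functional as F

def mixup_loss(x, classifier, device, alpha=1.0):
    lam = torch.distributions.Beta(alpha, alpha).sample().to(device)
    perm = torch.randperm(x.size(0), device=device)
    x_mix = lam * x + (1 - lam) * x[perm]                  # convex combination of inputs
    p = F.softmax(classifier(x), dim=1).detach()           # fixed targets
    p_mix = lam * p + (1 - lam) * p[perm]                  # mixed predictions
    logp_hat = F.log_softmax(classifier(x_mix), dim=1)
    return F.kl_div(logp_hat, p_mix, reduction='batchmean')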
def train(args):
    parameters = vars(args)
    valid_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    eval_model = args.evaluation.eval().to(args.device)  # renamed from `eval` to avoid shadowing the builtin
    print(generator)
    print(critic)

    optim_critic = optim.Adam(critic.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter2 = iter(train_loader2)
    titer, titer2 = iter(test_loader1), iter(test_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()
        for _ in range(args.d_updates):
            batch, iter2 = sample(iter2, train_loader2)
            data = batch[0].to(args.device)
            # Corrupt a fraction of the target labels, then one-hot encode them
            label = corrupt(batch[1], args.nc, args.corrupt_tgt)
            label = one_hot_embedding(label, args.nc).to(args.device)
            optim_critic.zero_grad()
            pos_loss, neg_loss, gp = critic_loss(data, label, args.z_dim, critic, generator, args.device)
            pos_loss.backward()
            neg_loss.backward(mone)
            (10 * gp).backward()
            optim_critic.step()

        optim_generator.zero_grad()
        t_loss = transfer_loss(data.shape[0], label, args.z_dim, critic, generator, args.device)
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            batch, titer = sample(titer, test_loader1)
            data1 = batch[0].to(args.device)
            label = one_hot_embedding(batch[1], args.nc).to(args.device)
            batch, titer2 = sample(titer2, test_loader2)
            data2 = batch[0].to(args.device)
            plot_transfer(args.visualiser, label, args.nc, data1, data2, args.nz, generator, args.device, i)
            save_path = args.save_path
            eval_accuracy = evaluate(valid_loader1, args.nz, args.nc, args.corrupt_src, generator, eval_model, args.device)
            test_accuracy = evaluate(test_loader1, args.nz, args.nc, args.corrupt_src, generator, eval_model, args.device)
            with open(os.path.join(save_path, 'critic_loss'), 'a') as f:
                f.write(f'{i},{(pos_loss-neg_loss).cpu().item()}\n')
            with open(os.path.join(save_path, 'tloss'), 'a') as f:
                f.write(f'{i},{t_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'eval_accuracy'), 'a') as f:
                f.write(f'{i},{eval_accuracy}\n')
            with open(os.path.join(save_path, 'test_accuracy'), 'a') as f:
                f.write(f'{i},{test_accuracy}\n')  # was logging eval_accuracy here
            args.visualiser.plot((pos_loss - neg_loss).cpu().detach().numpy(), title='critic_loss', step=i)
            args.visualiser.plot(t_loss.cpu().detach().numpy(), title='tloss', step=i)
            args.visualiser.plot(eval_accuracy, title='Validation transfer accuracy', step=i)
            args.visualiser.plot(test_accuracy, title='Test transfer accuracy', step=i)
            t0 = time.time()
        save_models(models, 0, args.model_path, args.checkpoint)
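# Hypothetical sketches of the two label helpers used above; neither is
# defined in this file, so the real implementations may differ.
# `one_hot_embedding` maps integer labels to one-hot vectors; `corrupt`
# resamples a fraction `p` of the labels uniformly at random.
import torch

def one_hot_embedding(labels, num_classes):
    return torch.eye(num_classes)[labels]          # (B,) long -> (B, num_classes) float

def corrupt(labels, num_classes, p):
    mask = torch.rand(labels.shape) < p            # which labels to corrupt
    noise = torch.randint(0, num_classes, labels.shape)
    return torch.where(mask, noise, labels)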
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    print(classifier)
    print(discriminator)

    optim_discriminator = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_classifier = optim.Adam(classifier.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optims = {'optim_discriminator': optim_discriminator, 'optim_classifier': optim_classifier}
    initialize(optims, args.reload, args.save_path, args.model_path)

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)
        label = batchx[1].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        optim_discriminator.zero_grad()
        d_loss = disc_loss(data1, data2, discriminator, classifier.x, classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        optim_classifier.zero_grad()
        c_loss = classification_loss(data1, label, classifier)
        tcw_loss = classification_target_loss(data2, classifier)
        dw_loss = embed_div_loss(data1, data2, discriminator, classifier.x, classifier.mlp, args.device)
        v_loss1 = vat_loss(data1, classifier, args.radius, args.device)
        v_loss2 = vat_loss(data2, classifier, args.radius, args.device)
        m_loss1 = mixup_loss(data1, classifier, args.device)
        m_loss2 = mixup_loss(data2, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.svw * v_loss1).backward()
        (args.tvw * v_loss2).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            classifier.eval()
            test_accuracy_x = evaluate(test_loader1, classifier, args.device)
            test_accuracy_y = evaluate(test_loader2, classifier, args.device)
            save_path = args.save_path
            with open(os.path.join(save_path, 'c_loss'), 'a') as f:
                f.write(f'{i},{c_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'tcw_loss'), 'a') as f:
                f.write(f'{i},{tcw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'dw_loss'), 'a') as f:
                f.write(f'{i},{dw_loss.cpu().item()}\n')
            with open(os.path.join(save_path, 'v_loss1'), 'a') as f:
                f.write(f'{i},{v_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'v_loss2'), 'a') as f:
                f.write(f'{i},{v_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss1'), 'a') as f:
                f.write(f'{i},{m_loss1.cpu().item()}\n')
            with open(os.path.join(save_path, 'm_loss2'), 'a') as f:
                f.write(f'{i},{m_loss2.cpu().item()}\n')
            with open(os.path.join(save_path, 'd_loss2'), 'a') as f:
                f.write(f'{i},{d_loss.cpu().item()}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_x'), 'a') as f: f.write(f'{i},{eval_accuracy_x}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_y'), 'a') as f: f.write(f'{i},{eval_accuracy_y}\n')
            with open(os.path.join(save_path, 'test_accuracy_x'), 'a') as f:
                f.write(f'{i},{test_accuracy_x}\n')
            with open(os.path.join(save_path, 'test_accuracy_y'), 'a') as f:
                f.write(f'{i},{test_accuracy_y}\n')
            args.visualiser.plot(c_loss.cpu().detach().numpy(), title='Source classifier loss', step=i)
            args.visualiser.plot(tcw_loss.cpu().detach().numpy(), title='Target classifier cross entropy', step=i)
            args.visualiser.plot(dw_loss.cpu().detach().numpy(), title='Classifier marginal divergence', step=i)
            args.visualiser.plot(v_loss1.cpu().detach().numpy(), title='Source virtual adversarial loss', step=i)
            args.visualiser.plot(v_loss2.cpu().detach().numpy(), title='Target virtual adversarial loss', step=i)
            args.visualiser.plot(m_loss1.cpu().detach().numpy(), title='Source mix up loss', step=i)
            args.visualiser.plot(m_loss2.cpu().detach().numpy(), title='Target mix up loss', step=i)
            args.visualiser.plot(d_loss.cpu().detach().numpy(), title='Discriminator loss', step=i)
            #args.visualiser.plot(eval_accuracy_x, title='Eval acc X', step=i)
            #args.visualiser.plot(eval_accuracy_y, title='Eval acc Y', step=i)
            args.visualiser.plot(test_accuracy_x, title='Test acc X', step=i)
            args.visualiser.plot(test_accuracy_y, title='Test acc Y', step=i)
            t0 = time.time()
        save_models(models, i, args.model_path, args.evaluate)
        save_models(optims, i, args.model_path, args.evaluate)
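# `vat_loss` (virtual adversarial training) is not defined in this file. The
# sketch below follows the standard recipe of Miyato et al. as an assumption:
# estimate, with one power iteration, the norm-bounded perturbation that most
# changes the prediction, then penalize the change at radius `radius`
# (matching args.radius above). The repo's version may differ in details.
import torch
import torch.nn.functional as F

def vat_loss(x, classifier, radius, device, xi=1e-6):
    with torch.no_grad():
        p = F.softmax(classifier(x), dim=1)                   # fixed target distribution
    d = torch.randn_like(x)
    norm_shape = (-1,) + (1,) * (x.dim() - 1)
    d = xi * d / (d.flatten(1).norm(dim=1).view(norm_shape) + 1e-8)
    d.requires_grad_()
    logp_hat = F.log_softmax(classifier(x + d), dim=1)
    adv_dist = F.kl_div(logp_hat, p, reduction='batchmean')
    grad = torch.autograd.grad(adv_dist, d)[0]                # direction of steepest change
    r_adv = radius * grad / (grad.flatten(1).norm(dim=1).view(norm_shape) + 1e-8)
    logp_hat = F.log_softmax(classifier(x + r_adv.detach()), dim=1)
    return F.kl_div(logp_hat, p, reduction='batchmean')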
def train(args):
    parameters = vars(args)
    train_loader1, valid_loader1, test_loader1 = args.loaders1
    train_loader2, valid_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    critic = models['critic'].to(args.device)
    generator = models['generator'].to(args.device)
    evalY = args.evalY.to(args.device).eval()
    semantic = args.semantic.to(args.device).eval()
    print(generator)
    print(critic)
    print(semantic)

    optim_critic = optim.Adam(critic.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        critic.train()
        generator.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)
            critic_lossy = compute_critic_loss(datax, args.z_dim, datay, critic, generator, args.device)
            optim_critic.zero_grad()
            critic_lossy.backward()
            optim_critic.step()

        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        glossxy = generator_loss(datax, args.z_dim, critic, generator, args.device)
        slossxy = semantic_loss(datax, args.z_dim, generator, semantic, args.device)
        optim_generator.zero_grad()
        glossxy.backward()
        (args.gsxy * slossxy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generator.eval()
            save_path = args.save_path
            plot_transfer(args.visualiser, datax, datay, args.z_dim, generator, 'x-y', i, args.device)
            test_accuracy_xy = evaluate(test_loader1, args.z_dim, generator, evalY, args.device)
            with open(os.path.join(save_path, 'glossxy'), 'a') as f:
                f.write(f'{i},{glossxy.cpu().item()}\n')
            with open(os.path.join(save_path, 'slossxy'), 'a') as f:
                f.write(f'{i},{slossxy.cpu().item()}\n')
            with open(os.path.join(save_path, 'test_accuracy_xy'), 'a') as f:
                f.write(f'{i},{test_accuracy_xy}\n')
            args.visualiser.plot(critic_lossy.cpu().detach().numpy(), title='critic_lossy', step=i)
            args.visualiser.plot(glossxy.cpu().detach().numpy(), title='glossxy', step=i)
            args.visualiser.plot(slossxy.cpu().detach().numpy(), title='slossxy', step=i)
            args.visualiser.plot(test_accuracy_xy, title='Test transfer accuracy X-Y', step=i)
            t0 = time.time()
        save_models(models, 0, args.model_path, args.checkpoint)
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    classifier = models['classifier'].to(args.device)
    discriminator = models['discriminator'].to(args.device)
    cluster = args.cluster.eval().to(args.device)
    print(classifier)
    print(discriminator)

    optim_classifier = optim.Adam(classifier.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    titer1 = iter(test_loader1)
    titer2 = iter(test_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        classifier.train()
        discriminator.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data1 = batchx[0].to(args.device)
        if data1.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            data1 = batchx[0].to(args.device)
        # Pseudo-labels for the source batch come from the frozen cluster model
        label = cluster(data1).argmax(1).detach()
        batchy, iter2 = sample(iter2, train_loader2)
        data2 = batchy[0].to(args.device)
        if data2.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            data2 = batchy[0].to(args.device)

        optim_discriminator.zero_grad()
        d_loss = disc_loss(data1, data2, discriminator, classifier.x, classifier.mlp, args.device)
        d_loss.backward()
        optim_discriminator.step()

        optim_classifier.zero_grad()
        c_loss = classification_loss(data1, label, classifier)
        tcw_loss = classification_target_loss(data2, classifier)
        dw_loss = embed_div_loss(data1, data2, discriminator, classifier.x, classifier.mlp, args.device)
        v_loss1 = vat_loss(data1, classifier, args.radius, args.device)
        v_loss2 = vat_loss(data2, classifier, args.radius, args.device)
        m_loss1 = mixup_loss(data1, classifier, args.device)
        m_loss2 = mixup_loss(data2, classifier, args.device)
        (args.cw * c_loss).backward()
        (args.tcw * tcw_loss).backward()
        (args.dw * dw_loss).backward()
        (args.svw * v_loss1).backward()
        (args.tvw * v_loss2).backward()
        (args.smw * m_loss1).backward()
        (args.tmw * m_loss2).backward()
        optim_classifier.step()

        if i % args.evaluate == 0:
            classifier.eval()
            batchx, titer1 = sample(titer1, test_loader1)
            data1 = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            data2 = batchy[0].to(args.device)
            print('Iter: %s' % i, time.time() - t0)
            class_map = evaluate_cluster(args.visualiser, i, args.nc, test_loader1, cluster, 'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, test_loader1, class_map, classifier, 'x', args.device)
            evaluate_cluster_accuracy(args.visualiser, i, test_loader2, class_map, classifier, 'y', args.device)
            args.visualiser.plot(c_loss.cpu().detach().numpy(), title='Classifier loss', step=i)
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    print(generator)

    optim_generator = optim.SGD(generator.parameters(), lr=args.lr)  #, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx.to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy.to(args.device)
        optim_generator.zero_grad()
        x = generator(data)
        # Sinkhorn divergences to both domains, mixed with weight alphat
        t_lossx = sinkhorn_loss(x, data, args.eps, data.shape[0], 100, args.device, p=args.lp)**args.p_exp
        t_lossy = sinkhorn_loss(x, datay, args.eps, data.shape[0], 100, args.device, p=args.lp)**args.p_exp
        ((1 - args.alphat) * t_lossx + args.alphat * t_lossy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            evaluate(args.visualiser, datax, datax, datay, generator, 'x')
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss Y')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
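# `sinkhorn_loss` is not shown. A minimal entropic-OT sketch matching the call
# signature above (eps = entropic regularization, n = batch size, niter =
# Sinkhorn iterations, p = ground-cost exponent), with uniform marginals and
# no debiasing; the repo's version is likely more elaborate.
import torch

def sinkhorn_loss(x, y, eps, n, niter, device, p=2):
    cost = torch.cdist(x.flatten(1), y.flatten(1), p=p) ** p   # (n, n) ground cost
    mu = torch.full((n,), 1.0 / n, device=device)              # uniform source weights
    nu = torch.full((n,), 1.0 / n, device=device)              # uniform target weights
    u = torch.zeros(n, device=device)
    v = torch.zeros(n, device=device)
    for _ in range(niter):  # log-domain Sinkhorn updates for numerical stability
        u = eps * (torch.log(mu) - torch.logsumexp((-cost + u[:, None] + v[None, :]) / eps, dim=1)) + u
        v = eps * (torch.log(nu) - torch.logsumexp((-cost + u[:, None] + v[None, :]) / eps, dim=0)) + v
    pi = torch.exp((-cost + u[:, None] + v[None, :]) / eps)    # transport plan
    return (pi * cost).sum()                                   # transport cost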
def train(args):
    parameters = vars(args)
    train_loader1, valid_loader1, test_loader1 = args.loaders1
    train_loader2, valid_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    criticX = models['criticX'].to(args.device)
    criticY = models['criticY'].to(args.device)
    generatorXY = models['generatorXY'].to(args.device)
    generatorYX = models['generatorYX'].to(args.device)
    evalX = args.evalX.to(args.device).eval()
    evalY = args.evalY.to(args.device).eval()
    classifier = args.classifier.to(args.device).eval()
    print(generatorXY)
    print(criticX)
    print(classifier)

    optim_criticX = optim.Adam(criticX.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticY = optim.Adam(criticY.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generatorXY = optim.Adam(generatorXY.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generatorYX = optim.Adam(generatorYX.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iter2 = iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        criticX.train()
        criticY.train()
        generatorXY.train()
        generatorYX.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].to(args.device)
            if datax.shape[0] != args.train_batch_size:
                batchx, iter1 = sample(iter1, train_loader1)
                datax = batchx[0].to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)
            if datay.shape[0] != args.train_batch_size:
                batchy, iter2 = sample(iter2, train_loader2)
                datay = batchy[0].to(args.device)
            #critic_lossx = compute_critic_loss(datay, args.z_dim, datax, criticX, generatorYX, args.device)
            #optim_criticX.zero_grad()
            #critic_lossx.backward()
            #optim_criticX.step()
            critic_lossy = compute_critic_loss(datax, args.z_dim, datay, criticY, generatorXY, args.device)
            optim_criticY.zero_grad()
            critic_lossy.backward()
            optim_criticY.step()

        batchx, iter1 = sample(iter1, train_loader1)
        datax = batchx[0].to(args.device)
        if datax.shape[0] != args.train_batch_size:
            batchx, iter1 = sample(iter1, train_loader1)
            datax = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        if datay.shape[0] != args.train_batch_size:
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)
        glossxy = generator_loss(datax, args.z_dim, criticY, generatorXY, args.device)
        slossxy = semantic_loss(datax, args.z_dim, generatorXY, classifier, args.device)
        #cyclelossxy = cycle_loss(datax, args.z_dim, generatorXY, generatorYX, args.device)
        #idlossxy = identity_loss(datax, args.z_dim, generatorXY, args.device)
        optim_generatorXY.zero_grad()
        glossxy.backward()
        (args.gsxy * slossxy).backward()
        #(args.gcxy*cyclelossxy).backward()
        #(args.gixy*idlossxy).backward()
        optim_generatorXY.step()

        #glossyx = generator_loss(datay, args.z_dim, criticX, generatorYX, args.device)
        #slossyx = semantic_loss(datay, args.z_dim, generatorYX, classifier, args.device)
        #cyclelossyx = cycle_loss(datay, args.z_dim, generatorYX, generatorXY, args.device)
        #idlossyx = identity_loss(datay, args.z_dim, generatorYX, args.device)
        #optim_generatorYX.zero_grad()
        #glossyx.backward()
        #(args.gsyx*slossyx).backward()
        #(args.gcyx*cyclelossyx).backward()
        #(args.giyx*idlossyx).backward()
        #optim_generatorYX.step()

        if i % args.evaluate == 0:
            print('Iter: %s' % i, time.time() - t0)
            generatorXY.eval()
            generatorYX.eval()
            save_path = args.save_path
            plot_transfer(args.visualiser, datax, datay, args.z_dim, generatorXY, 'x-y', i, args.device)
            #plot_transfer(args.visualiser, datay, datax, args.z_dim, generatorYX, 'y-x', i, args.device)
            #eval_accuracy_xy = evaluate(valid_loader1, args.z_dim, generatorXY, evalY, args.device)
            #eval_accuracy_yx = evaluate(valid_loader2, args.z_dim, generatorYX, evalX, args.device)
            test_accuracy_xy = evaluate(test_loader1, args.z_dim, generatorXY, evalY, args.device)
            #test_accuracy_yx = evaluate(test_loader2, args.z_dim, generatorYX, evalX, args.device)
            #with open(os.path.join(save_path, 'critic_lossx'), 'a') as f: f.write(f'{i},{critic_lossx.cpu().item()}\n')
            #with open(os.path.join(save_path, 'critic_lossy'), 'a') as f: f.write(f'{i},{critic_lossy.cpu().item()}\n')
            with open(os.path.join(save_path, 'glossxy'), 'a') as f:
                f.write(f'{i},{glossxy.cpu().item()}\n')
            #with open(os.path.join(save_path, 'glossyx'), 'a') as f: f.write(f'{i},{glossyx.cpu().item()}\n')
            with open(os.path.join(save_path, 'slossxy'), 'a') as f:
                f.write(f'{i},{slossxy.cpu().item()}\n')
            #with open(os.path.join(save_path, 'slossyx'), 'a') as f: f.write(f'{i},{slossyx.cpu().item()}\n')
            #with open(os.path.join(save_path, 'idlossxy'), 'a') as f: f.write(f'{i},{idlossxy.cpu().item()}')
            #with open(os.path.join(save_path, 'idlossyx'), 'a') as f: f.write(f'{i},{idlossyx.cpu().item()}')
            #with open(os.path.join(save_path, 'cyclelossxy'), 'a') as f: f.write(f'{i},{cyclelossxy.cpu().item()}\n')
            #with open(os.path.join(save_path, 'cyclelossyx'), 'a') as f: f.write(f'{i},{cyclelossyx.cpu().item()}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_xy'), 'a') as f: f.write(f'{i},{eval_accuracy_xy}\n')
            #with open(os.path.join(save_path, 'eval_accuracy_yx'), 'a') as f: f.write(f'{i},{eval_accuracy_yx}\n')
            with open(os.path.join(save_path, 'test_accuracy_xy'), 'a') as f:
                f.write(f'{i},{test_accuracy_xy}\n')
            #with open(os.path.join(save_path, 'test_accuracy_yx'), 'a') as f: f.write(f'{i},{test_accuracy_yx}\n')
            #args.visualiser.plot(critic_lossx.cpu().detach().numpy(), title='critic_lossx', step=i)
            args.visualiser.plot(critic_lossy.cpu().detach().numpy(), title='critic_lossy', step=i)
            args.visualiser.plot(glossxy.cpu().detach().numpy(), title='glossxy', step=i)
            #args.visualiser.plot(glossyx.cpu().detach().numpy(), title='glossyx', step=i)
            args.visualiser.plot(slossxy.cpu().detach().numpy(), title='slossxy', step=i)
            #args.visualiser.plot(slossyx.cpu().detach().numpy(), title='slossyx', step=i)
            #args.visualiser.plot(idlossxy.cpu().detach().numpy(), title='idlossxy', step=i)
            #args.visualiser.plot(idlossyx.cpu().detach().numpy(), title='idlossyx', step=i)
            #args.visualiser.plot(cyclelossxy.cpu().detach().numpy(), title='cyclelossxy', step=i)
            #args.visualiser.plot(cyclelossyx.cpu().detach().numpy(), title='cyclelossyx', step=i)
            #args.visualiser.plot(eval_accuracy_xy, title='Validation transfer accuracy X-Y', step=i)
            #args.visualiser.plot(eval_accuracy_yx, title='Validation transfer accuracy Y-X', step=i)
            args.visualiser.plot(test_accuracy_xy, title='Test transfer accuracy X-Y', step=i)
            #args.visualiser.plot(test_accuracy_yx, title='Test transfer accuracy Y-X', step=i)
            t0 = time.time()
        save_models(models, 0, args.model_path, args.checkpoint)
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    critic = models['critic'].to(args.device)
    encoder = models['encoder'].to(args.device)
    print(generator)
    print(critic)
    print(encoder)

    optim_critic = optim.Adam(critic.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_encoder = optim.Adam(encoder.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1 = iter(train_loader1)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1 = iter(test_loader1)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        critic.train()
        encoder.train()
        # Joint encoder/generator updates
        for _ in range(10):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_encoder.zero_grad()
            optim_generator.zero_grad()
            e_loss = encoder_loss(data.shape[0], args.lp, args.z_dim, encoder, generator, critic, args.device)
            e_loss.backward()
            optim_encoder.step()
            optim_generator.step()
        # Critic update
        for _ in range(1):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_critic.zero_grad()
            r_loss = critic_loss(data, args.lp, args.z_dim, encoder, critic, generator, args.device)
            r_loss.backward(mone)
            optim_critic.step()
        # Generator update
        for _ in range(1):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            optim_generator.zero_grad()
            t_loss = transfer_loss(data.shape[0], args.lp, args.z_dim, encoder, critic, generator, args.device)
            t_loss.backward()
            optim_generator.step()

        if i % args.evaluate == 0:
            encoder.eval()
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            data = batchx[0].to(args.device)
            evaluate(args.visualiser, args.z_dim, data, encoder, generator, critic, args.z_dim, i, args.device)
            d_loss = r_loss.detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss')
            args.visualiser.plot(step=i, data=e_loss.detach().cpu().numpy(), title='Encoder loss')
            args.visualiser.plot(step=i, data=t_loss.detach().cpu().numpy(), title='Generator loss')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
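# `save_models`, called at the end of every loop, is not shown either. A
# hypothetical sketch consistent with the '<name>:<iteration>' naming assumed
# in the infer_iteration sketch earlier; the real helper may differ.
import os
import torch

def save_models(models, i, model_path, checkpoint):
    if checkpoint and i % checkpoint != 0:   # only save every `checkpoint` steps
        return
    for name, model in models.items():
        torch.save(model.state_dict(), os.path.join(model_path, f'{name}:{i}'))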
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    train_loader3, test_loader3 = args.loaders3
    train_loader4, test_loader4 = args.loaders4
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    criticz1 = models['criticz1'].to(args.device)
    criticz2 = models['criticz2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)

    optim_criticx1 = optim.Adam(criticx1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticz1 = optim.Adam(criticz1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticz2 = optim.Adam(criticz2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2, iter3, iter4 = iter(train_loader1), iter(train_loader2), iter(train_loader3), iter(train_loader4)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2, titer3, titer4 = iter(test_loader1), iter(test_loader2), iter(test_loader3), iter(test_loader4)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    generator.train()
    criticx1.train()
    criticx2.train()
    criticy1.train()
    criticy2.train()
    criticz1.train()
    criticz2.train()
    # Critic pre-training loop, currently disabled (range(0) executes nothing)
    for i in range(0):
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        # Zero out all but one coordinate of each auxiliary batch
        # (y keeps dim 0, z keeps dim 1, w keeps dim 2)
        datay[:, 1] = datay[:, 1] * 0
        datay[:, 2] = datay[:, 2] * 0
        batchz, iter3 = sample(iter3, train_loader3)
        dataz = batchz[0].to(args.device)
        dataz[:, 0] = dataz[:, 0] * 0
        dataz[:, 2] = dataz[:, 2] * 0
        batchw, iter4 = sample(iter4, train_loader4)
        dataw = batchw[0].to(args.device)
        dataw[:, 0] = dataw[:, 0] * 0
        dataw[:, 1] = dataw[:, 1] * 0

        optim_criticx1.zero_grad()
        optim_criticx2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps, args.lp, criticx1, criticx2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticx1.step()
        optim_criticx2.step()

        optim_criticy1.zero_grad()
        optim_criticy2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, dataz, args.eps, args.lp, criticy1, criticy2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticy1.step()
        optim_criticy2.step()

        optim_criticz1.zero_grad()
        optim_criticz2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, dataw, args.eps, args.lp, criticz1, criticz2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticz1.step()
        optim_criticz2.step()

        if i % 100 == 0:
            print(f'Critics-{i}')
            print('Iter: %s' % i, time.time() - t0)
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Z')
            t0 = time.time()

    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        criticz1.train()
        criticz2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx[0].to(args.device)
            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy[0].to(args.device)
            datay[:, 1] = datay[:, 1] * 0
            datay[:, 2] = datay[:, 2] * 0
            batchz, iter3 = sample(iter3, train_loader3)
            dataz = batchz[0].to(args.device)
            dataz[:, 0] = dataz[:, 0] * 0
            dataz[:, 2] = dataz[:, 2] * 0
            batchw, iter4 = sample(iter4, train_loader4)
            dataw = batchw[0].to(args.device)
            dataw[:, 0] = dataw[:, 0] * 0
            dataw[:, 1] = dataw[:, 1] * 0

            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps, args.lp, criticx1, criticx2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, dataz, args.eps, args.lp, criticy1, criticy2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

            optim_criticz1.zero_grad()
            optim_criticz2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, dataw, args.eps, args.lp, criticz1, criticz2)
            (r_loss + g_loss + p).backward(mone)
            optim_criticz1.step()
            optim_criticz2.step()

        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        datay[:, 1] = datay[:, 1] * 0
        datay[:, 2] = datay[:, 2] * 0
        batchz, iter3 = sample(iter3, train_loader3)
        dataz = batchz[0].to(args.device)
        dataz[:, 0] = dataz[:, 0] * 0
        dataz[:, 2] = dataz[:, 2] * 0
        batchw, iter4 = sample(iter4, train_loader4)
        dataw = batchw[0].to(args.device)
        dataw[:, 0] = dataw[:, 0] * 0
        dataw[:, 1] = dataw[:, 1] * 0

        optim_generator.zero_grad()
        t_ = Dirichlet(torch.FloatTensor([1., 1., 1.])).sample().to(args.device)
        t = torch.stack([t_] * data.shape[0])
        t_lossx = transfer_loss(data, datay, t, args.eps, args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(data, dataz, t, args.eps, args.lp, criticy1, criticy2, generator)
        t_lossz = transfer_loss(data, dataw, t, args.eps, args.lp, criticz1, criticz2, generator)
        t_loss = (t_[0] * t_lossx + t_[1] * t_lossy + t_[2] * t_lossz).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            #batchx, titer1 = sample(titer1, test_loader1)
            #datax = batchx[0].to(args.device)
            #batchy, titer2 = sample(titer2, test_loader2)
            #datay = batchy[0].to(args.device)
            evaluate(args.visualiser, data, datay, dataz, dataw, generator, 'x', args.device)
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss Y')
            args.visualiser.plot(step=i, data=t_lossz.detach().cpu().numpy(), title='Generator loss Z')
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()  # from the last critic update (z)
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Z')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticx2)
    print(criticy1)
    print(criticy2)

    optim_criticx1 = optim.Adam(criticx1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()
    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        for _ in range(args.d_updates):
            batchx, iter1 = sample(iter1, train_loader1)
            data = batchx.to(args.device)
            optim_criticx1.zero_grad()
            optim_criticx2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, data, args.z_dim, args.eps, args.lp, criticx1, criticx2, generator, args.device)
            (r_loss + g_loss + p).backward(mone)
            optim_criticx1.step()
            optim_criticx2.step()

            batchy, iter2 = sample(iter2, train_loader2)
            datay = batchy.to(args.device)
            optim_criticy1.zero_grad()
            optim_criticy2.zero_grad()
            r_loss, g_loss, p = disc_loss_generation(data, datay, args.z_dim, args.eps, args.lp, criticy1, criticy2, generator, args.device)
            (r_loss + g_loss + p).backward(mone)
            optim_criticy1.step()
            optim_criticy2.step()

        optim_generator.zero_grad()
        t_lossx = transfer_loss(data, data, args.z_dim, args.eps, args.lp, criticx1, criticx2, generator, args.device)
        t_lossy = transfer_loss(data, datay, args.z_dim, args.eps, args.lp, criticy1, criticy2, generator, args.device)
        (0.5 * t_lossx + 0.5 * t_lossy).backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx.to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy.to(args.device)
            data = torch.randn(args.test_batch_size, args.z_dim, device=args.device)
            evaluate(args.visualiser, data, datax, datay, generator, 'x')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss')
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss Y')
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
def train(args):
    parameters = vars(args)
    train_loader1, test_loader1 = args.loaders1
    train_loader2, test_loader2 = args.loaders2
    models = define_models(**parameters)
    initialize(models, args.reload, args.save_path, args.model_path)
    generator = models['generator'].to(args.device)
    criticx1 = models['criticx1'].to(args.device)
    criticx2 = models['criticx2'].to(args.device)
    criticy1 = models['criticy1'].to(args.device)
    criticy2 = models['criticy2'].to(args.device)
    print(generator)
    print(criticx1)
    print(criticy1)

    optim_criticx1 = optim.Adam(criticx1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticx2 = optim.Adam(criticx2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy1 = optim.Adam(criticy1.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_criticy2 = optim.Adam(criticy2.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    iter1, iter2 = iter(train_loader1), iter(train_loader2)
    iteration = infer_iteration(list(models.keys())[0], args.reload, args.model_path, args.save_path)
    titer1, titer2 = iter(test_loader1), iter(test_loader2)
    mone = torch.FloatTensor([-1]).to(args.device)
    t0 = time.time()

    # Pre-train the critics for 4000 steps before touching the generator
    generator.train()
    criticx1.train()
    criticx2.train()
    criticy1.train()
    criticy2.train()
    for i in range(4000):
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        optim_criticx1.zero_grad()
        optim_criticx2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, data, args.eps, args.lp, criticx1, criticx2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticx1.step()
        optim_criticx2.step()
        optim_criticy1.zero_grad()
        optim_criticy2.zero_grad()
        r_loss, g_loss, p = disc_loss_generation(data, datay, args.eps, args.lp, criticy1, criticy2)
        (r_loss + g_loss + p).backward(mone)
        optim_criticy1.step()
        optim_criticy2.step()
        if i % 100 == 0:
            print(f'Critics-{i}')
            print('Iter: %s' % i, time.time() - t0)
            args.visualiser.plot(step=i, data=p.detach().cpu().numpy(), title='Penalty')
            d_loss = (r_loss + g_loss).detach().cpu().numpy()
            args.visualiser.plot(step=i, data=d_loss, title='Critic loss Y')
            t0 = time.time()

    for i in range(iteration, args.iterations):
        generator.train()
        criticx1.train()
        criticx2.train()
        criticy1.train()
        criticy2.train()
        batchx, iter1 = sample(iter1, train_loader1)
        data = batchx[0].to(args.device)
        #if data.shape[0] != args.train_batch_size:
        #    batchx, iter1 = sample(iter1, train_loader1)
        #    data = batchx[0].to(args.device)
        batchy, iter2 = sample(iter2, train_loader2)
        datay = batchy[0].to(args.device)
        #if datay.shape[0] != args.train_batch_size:
        #    batchy, iter2 = sample(iter2, train_loader2)
        #    datay = batchy[0].to(args.device)

        optim_generator.zero_grad()
        # Sample nt interpolation times and tile the batch once per time
        t_ = torch.rand(args.nt, device=args.device)
        t = torch.stack([t_] * data.shape[0])
        #t = torch.stack([t_] * input_data.shape[0]).transpose(0, 1).reshape(-1, 1)
        tdata = torch.cat([data] * args.nt)
        tdatay = torch.cat([datay] * args.nt)
        t_lossx = transfer_loss(tdata, tdata, args.nt, t, args.eps, args.lp, criticx1, criticx2, generator)
        t_lossy = transfer_loss(tdata, tdatay, args.nt, t, args.eps, args.lp, criticy1, criticy2, generator)
        t_loss = ((1 - t_) * t_lossx + t_ * t_lossy).sum()
        t_loss.backward()
        optim_generator.step()

        if i % args.evaluate == 0:
            generator.eval()
            print('Iter: %s' % i, time.time() - t0)
            batchx, titer1 = sample(titer1, test_loader1)
            datax = batchx[0].to(args.device)
            batchy, titer2 = sample(titer2, test_loader2)
            datay = batchy[0].to(args.device)
            evaluate(args.visualiser, datax, datay, generator, 'x', args.device)
            args.visualiser.plot(step=i, data=t_lossx.detach().cpu().numpy(), title='Generator loss X')
            args.visualiser.plot(step=i, data=t_lossy.detach().cpu().numpy(), title='Generator loss Y')
            t0 = time.time()
        save_models(models, i, args.model_path, args.checkpoint)
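# Shape illustration for the tiled interpolation batch used above (the values
# are stand-ins). With batch size B and args.nt interpolation times, the
# generator presumably sees B*nt inputs, each associated with one of the nt
# sampled times t_:
import torch

B, nt, d = 4, 3, 2
data = torch.randn(B, d)
t_ = torch.rand(nt)
tdata = torch.cat([data] * nt)   # (B*nt, d): one copy of the batch per time
t = torch.stack([t_] * B)        # (B, nt): every sample shares the same nt times
print(tdata.shape, t.shape)      # torch.Size([12, 2]) torch.Size([4, 3])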