def __init__(self, args):
    """Initialize data, networks, optimizers, visualization and checkpointing.

    All configuration is read from ``args`` (presumably the parsed
    command-line namespace -- TODO confirm). Every network, including the
    auxiliary classifiers, is placed on ``self.device``. That device is
    ``'cuda'`` only when ``args.cuda`` is truthy and
    ``torch.cuda.is_available()`` is True; otherwise it is ``'cpu'``.

    Side effects: creates two ``tqdm`` progress bars; calls ``mkdirs`` for
    ``<ckpt_dir>/<name>/cls``, ``<ckpt_dir>/<name>/vae`` and
    ``<output_dir>/<name>``; loads a checkpoint when ``args.ckpt_load`` is
    truthy; and, when ``args.viz_on`` is truthy, creates a Visdom client on
    ``args.viz_port``.

    Args:
        args: configuration object exposing every attribute read below.
    """
    # Misc
    use_cuda = args.cuda and torch.cuda.is_available()
    self.device = 'cuda' if use_cuda else 'cpu'
    self.name = args.name
    self.max_iter = int(args.max_iter)
    self.print_iter = args.print_iter
    self.global_iter = 0
    # Separate counter/progress bar with a "_cls" suffix, presumably for the
    # classifier training loop -- confirm against the training methods.
    self.global_iter_cls = 0
    self.pbar = tqdm(total=self.max_iter)
    self.pbar_cls = tqdm(total=self.max_iter)

    # Data
    self.dset_dir = args.dset_dir
    self.dataset = args.dataset
    self.batch_size = args.batch_size
    self.eval_batch_size = args.eval_batch_size
    # NOTE(review): the second argument of return_data presumably selects the
    # data split (0 for training, 2 for evaluation) -- confirm in return_data.
    self.data_loader = return_data(args, 0)
    self.data_loader_eval = return_data(args, 2)

    # Networks & Optimizers
    self.z_dim = args.z_dim
    self.gamma = args.gamma
    self.beta = args.beta
    self.lr_VAE = args.lr_VAE
    self.beta1_VAE = args.beta1_VAE
    self.beta2_VAE = args.beta2_VAE
    self.lr_D = args.lr_D
    self.beta1_D = args.beta1_D
    self.beta2_D = args.beta2_D
    self.alpha = args.alpha
    self.grl = args.grl
    self.lr_cls = args.lr_cls
    # NOTE(review): the classifier betas are copied from the discriminator's
    # settings, and none of the classifier optimizers below pass them to
    # Adam (they use Adam's defaults) -- confirm this is intended.
    self.beta1_cls = args.beta1_D
    self.beta2_cls = args.beta2_D

    if args.dataset == 'dsprites':
        self.VAE = FactorVAE1(self.z_dim).to(self.device)
        self.nc = 1
    else:
        self.VAE = FactorVAE2(self.z_dim).to(self.device)
        self.nc = 3
    self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE,
                                betas=(self.beta1_VAE, self.beta2_VAE))

    # Auxiliary classifiers. They are moved to self.device, like the VAE and
    # discriminator, instead of a hard-coded .cuda(). This lets the model be
    # built when CUDA is disabled via args.cuda or is unavailable.
    self.pacls = classifier(30, 2).to(self.device)
    self.revcls = classifier(30, 2).to(self.device)
    self.tcls = classifier(30, 2).to(self.device)
    self.trevcls = classifier(30, 2).to(self.device)
    self.targetcls = classifier(59, 2).to(self.device)
    self.pa_target = classifier(30, 2).to(self.device)
    self.target_pa = paclassifier(1, 1).to(self.device)
    self.pa_pa = classifier(30, 2).to(self.device)

    self.D = Discriminator(self.z_dim).to(self.device)
    self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                              betas=(self.beta1_D, self.beta2_D))

    # The first four classifiers share the discriminator's learning rate; the
    # remaining ones use lr_cls.
    self.optim_pacls = optim.Adam(self.pacls.parameters(), lr=self.lr_D)
    self.optim_revcls = optim.Adam(self.revcls.parameters(), lr=self.lr_D)
    self.optim_tcls = optim.Adam(self.tcls.parameters(), lr=self.lr_D)
    self.optim_trevcls = optim.Adam(self.trevcls.parameters(), lr=self.lr_D)
    self.optim_cls = optim.Adam(self.targetcls.parameters(), lr=self.lr_cls)
    self.optim_pa_target = optim.Adam(self.pa_target.parameters(), lr=self.lr_cls)
    self.optim_target_pa = optim.Adam(self.target_pa.parameters(), lr=self.lr_cls)
    self.optim_pa_pa = optim.Adam(self.pa_pa.parameters(), lr=self.lr_cls)

    # NOTE(review): target_pa and pa_pa have optimizers but are not listed in
    # self.nets -- confirm whether they should be.
    self.nets = [
        self.VAE,
        self.D,
        self.pacls,
        self.targetcls,
        self.revcls,
        self.pa_target,
        self.tcls,
        self.trevcls,
    ]

    # Visdom
    self.viz_on = args.viz_on
    self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld', acc='win_acc')
    self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc')
    self.image_gather = DataGather('true', 'recon')
    if self.viz_on:
        self.viz_port = args.viz_port
        self.viz = visdom.Visdom(port=self.viz_port)
        self.viz_ll_iter = args.viz_ll_iter
        self.viz_la_iter = args.viz_la_iter
        self.viz_ra_iter = args.viz_ra_iter
        self.viz_ta_iter = args.viz_ta_iter
        # Only (re)initialize the plot windows if the D_z window is missing.
        if not self.viz.win_exists(env=self.name + '/lines', win=self.win_id['D_z']):
            self.viz_init()

    # Checkpoint
    self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    self.ckpt_save_iter = args.ckpt_save_iter
    mkdirs(os.path.join(self.ckpt_dir, "cls"))
    mkdirs(os.path.join(self.ckpt_dir, "vae"))
    if args.ckpt_load:
        self.load_checkpoint(args.ckpt_load)

    # Output (latent traverse GIF)
    self.output_dir = os.path.join(args.output_dir, args.name)
    self.output_save = args.output_save
    mkdirs(self.output_dir)
def __init__(self, args):
    """Set up the FactorVAE training state from a configuration object.

    Reads every setting from ``args`` (presumably the parsed command-line
    namespace -- TODO confirm). The target device is ``'cuda'`` only when
    ``args.cuda`` is truthy and ``torch.cuda.is_available()`` is True;
    otherwise ``'cpu'``.

    Side effects: creates a ``tqdm`` progress bar; calls ``return_data``;
    when ``args.viz_on`` is truthy, creates an offline Visdom client logging
    to ``./logging.log``; calls ``mkdirs`` for the checkpoint and output
    directories; and loads a checkpoint when ``args.ckpt_load`` is truthy.

    Args:
        args: configuration object exposing every attribute read below.
    """
    # --- bookkeeping -------------------------------------------------------
    self.device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    self.name = args.name
    self.max_iter = int(args.max_iter)
    self.print_iter = args.print_iter
    self.global_iter = 0
    self.pbar = tqdm(total=self.max_iter)

    # --- data --------------------------------------------------------------
    for field in ('dset_dir', 'dataset', 'batch_size'):
        setattr(self, field, getattr(args, field))
    self.data_loader, self.data = return_data(args)

    # --- model hyperparameters ---------------------------------------------
    hyper_fields = (
        'z_dim', 'gamma',
        'lr_VAE', 'beta1_VAE', 'beta2_VAE',
        'lr_D', 'beta1_D', 'beta2_D',
    )
    for field in hyper_fields:
        setattr(self, field, getattr(args, field))

    # --- networks & optimizers ---------------------------------------------
    # dSprites uses the FactorVAE1 architecture with 1 channel; every other
    # dataset uses FactorVAE2 with 3 channels.
    vae_factory, channels = (FactorVAE1, 1) if args.dataset == 'dsprites' else (FactorVAE2, 3)
    self.VAE = vae_factory(self.z_dim).to(self.device)
    self.nc = channels
    vae_betas = (self.beta1_VAE, self.beta2_VAE)
    self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE, betas=vae_betas)

    self.D = Discriminator(self.z_dim).to(self.device)
    disc_betas = (self.beta1_D, self.beta2_D)
    self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D, betas=disc_betas)

    self.nets = [self.VAE, self.D]

    # --- visualization -----------------------------------------------------
    self.viz_on = args.viz_on
    self.win_id = {'D_z': 'win_D_z', 'recon': 'win_recon', 'kld': 'win_kld', 'acc': 'win_acc'}
    self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc')
    self.image_gather = DataGather('true', 'recon')
    if self.viz_on:
        # NOTE(review): viz_port is stored but not passed to Visdom, which
        # runs in offline (log-to-file) mode here.
        self.viz_port = args.viz_port
        self.viz = visdom.Visdom(log_to_filename='./logging.log', offline=True)
        for field in ('viz_ll_iter', 'viz_la_iter', 'viz_ra_iter', 'viz_ta_iter'):
            setattr(self, field, getattr(args, field))
        lines_env = self.name + '/lines'
        if not self.viz.win_exists(env=lines_env, win=self.win_id['D_z']):
            self.viz_init()

    # --- checkpointing -----------------------------------------------------
    self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    self.ckpt_save_iter = args.ckpt_save_iter
    mkdirs(self.ckpt_dir)
    if args.ckpt_load:
        self.load_checkpoint(args.ckpt_load)

    # --- output (latent traverse GIF) --------------------------------------
    self.output_dir = os.path.join(args.output_dir, args.name)
    self.output_save = args.output_save
    mkdirs(self.output_dir)