Exemplo n.º 1
0
    def __init__(self, args):
        """Initialise run state, networks, optimizers, visdom logging and output dirs.

        Every network created here is placed on ``self.device``, which is
        ``'cuda'`` only when ``args.cuda`` is truthy and
        ``torch.cuda.is_available()`` returns True; otherwise ``'cpu'``.

        Side effects: creates two tqdm progress bars, builds two data loaders
        via ``return_data``, optionally opens a visdom connection, creates the
        checkpoint/output directories via ``mkdirs`` and, when
        ``args.ckpt_load`` is truthy, calls ``self.load_checkpoint``.

        Args:
            args: Namespace-like object exposing the attributes read below
                (``cuda``, ``name``, ``max_iter``, ``dataset``, ``z_dim``,
                learning rates, visdom and checkpoint options, ...).
        """
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        # Separate iteration counters / progress bars; the "_cls" pair is
        # presumably for the classifier training phase -- confirm against
        # the training methods.
        self.global_iter = 0
        self.global_iter_cls = 0
        self.pbar = tqdm(total=self.max_iter)
        self.pbar_cls = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        # NOTE(review): the second argument (0 / 2) looks like a split or
        # mode selector -- confirm against return_data.
        self.data_loader = return_data(args, 0)
        self.data_loader_eval = return_data(args, 2)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.beta = args.beta

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D
        self.alpha = args.alpha
        self.grl = args.grl

        self.lr_cls = args.lr_cls
        # NOTE(review): the classifier betas are read from the discriminator
        # arguments (beta1_D / beta2_D) and are not passed to any optimizer
        # built below -- confirm this is intended.
        self.beta1_cls = args.beta1_D
        self.beta2_cls = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        # Place the auxiliary classifiers on self.device (not an unconditional
        # .cuda()) so they share a device with VAE/D and the constructor also
        # works when CUDA is disabled or unavailable.
        self.pacls = classifier(30, 2).to(self.device)
        self.revcls = classifier(30, 2).to(self.device)
        self.tcls = classifier(30, 2).to(self.device)
        self.trevcls = classifier(30, 2).to(self.device)

        self.targetcls = classifier(59, 2).to(self.device)
        self.pa_target = classifier(30, 2).to(self.device)
        self.target_pa = paclassifier(1, 1).to(self.device)
        self.pa_pa = classifier(30, 2).to(self.device)

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        # These four classifiers are optimised with the discriminator's
        # learning rate and Adam's default betas.
        self.optim_pacls = optim.Adam(self.pacls.parameters(), lr=self.lr_D)

        self.optim_revcls = optim.Adam(self.revcls.parameters(), lr=self.lr_D)

        self.optim_tcls = optim.Adam(self.tcls.parameters(), lr=self.lr_D)
        self.optim_trevcls = optim.Adam(self.trevcls.parameters(),
                                        lr=self.lr_D)

        # The remaining classifiers use the dedicated classifier learning rate.
        self.optim_cls = optim.Adam(self.targetcls.parameters(),
                                    lr=self.lr_cls)
        self.optim_pa_target = optim.Adam(self.pa_target.parameters(),
                                          lr=self.lr_cls)
        self.optim_target_pa = optim.Adam(self.target_pa.parameters(),
                                          lr=self.lr_cls)
        self.optim_pa_pa = optim.Adam(self.pa_pa.parameters(), lr=self.lr_cls)

        # NOTE(review): target_pa and pa_pa are not part of this list --
        # confirm whether the code that iterates self.nets should see them.
        self.nets = [
            self.VAE, self.D, self.pacls, self.targetcls, self.revcls,
            self.pa_target, self.tcls, self.trevcls
        ]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z',
                           recon='win_recon',
                           kld='win_kld',
                           acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm',
                                      'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            # Only call viz_init when the D_z window is not already present
            # in this run's "lines" environment.
            if not self.viz.win_exists(env=self.name + '/lines',
                                       win=self.win_id['D_z']):
                self.viz_init()

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir + "/cls")
        mkdirs(self.ckpt_dir + "/vae")

        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)
Exemplo n.º 2
0
    def __init__(self, args):
        """Construct the solver: run state, data, VAE/discriminator, logging, dirs.

        Reads its configuration from attributes of ``args``. Builds a tqdm
        progress bar, obtains the data loader and dataset from
        ``return_data``, creates the VAE and discriminator with their Adam
        optimizers on ``self.device``, optionally sets up visdom, creates the
        checkpoint and output directories via ``mkdirs`` and, when
        ``args.ckpt_load`` is truthy, calls ``self.load_checkpoint``.

        Args:
            args: Namespace-like object exposing the options read below.
        """
        # --- Misc -----------------------------------------------------------
        # Fall back to CPU unless CUDA was requested and is available.
        if args.cuda and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # --- Data -----------------------------------------------------------
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        loader, data = return_data(args)
        self.data_loader = loader
        self.data = data

        # --- Networks & optimizers -------------------------------------------
        self.z_dim = args.z_dim
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # 'dsprites' selects FactorVAE1 with one channel; every other dataset
        # selects FactorVAE2 with three.
        if args.dataset != 'dsprites':
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        else:
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        vae_betas = (self.beta1_VAE, self.beta2_VAE)
        self.optim_VAE = optim.Adam(
            self.VAE.parameters(),
            lr=self.lr_VAE,
            betas=vae_betas,
        )

        self.D = Discriminator(self.z_dim).to(self.device)
        d_betas = (self.beta1_D, self.beta2_D)
        self.optim_D = optim.Adam(
            self.D.parameters(),
            lr=self.lr_D,
            betas=d_betas,
        )

        self.nets = [self.VAE, self.D]

        # --- Visdom ---------------------------------------------------------
        self.viz_on = args.viz_on
        self.win_id = {
            'D_z': 'win_D_z',
            'recon': 'win_recon',
            'kld': 'win_kld',
            'acc': 'win_acc',
        }
        line_fields = ('iter', 'soft_D_z', 'soft_D_z_pperm',
                       'recon', 'kld', 'acc')
        self.line_gather = DataGather(*line_fields)
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            # The client is created in offline mode, logging to a local file.
            self.viz = visdom.Visdom(log_to_filename='./logging.log',
                                     offline=True)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            # Initialise the plots only if the D_z window is not already
            # present in this run's "lines" environment.
            lines_env = self.name + '/lines'
            window_present = self.viz.win_exists(env=lines_env,
                                                 win=self.win_id['D_z'])
            if not window_present:
                self.viz_init()

        # --- Checkpoint -----------------------------------------------------
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # --- Output (latent traverse GIF) -------------------------------------
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)