Esempio n. 1
0
File: snl.py Progetto: yyht/lfi
def main():
    """Run SNL on the "mg1" task and save a plot of the posterior marginals.

    Builds the simulator and prior for the task, trains a "maf" neural
    likelihood for 10 rounds of 1000 simulations each, draws 1000 posterior
    samples and writes their marginal histograms (with the simulator's
    ground-truth parameters overlaid) to ``./corner-posterior-snl.pdf``.
    """
    simulator, prior = simulators.get_simulator_and_prior("mg1")
    parameter_dim = simulator.parameter_dim
    observation_dim = simulator.observation_dim
    observed = simulator.get_ground_truth_observation()

    density_estimator = utils.get_neural_likelihood(
        "maf", parameter_dim, observation_dim
    )
    snl = SNL(
        simulator=simulator,
        true_observation=observed,
        prior=prior,
        neural_likelihood=density_estimator,
        mcmc_method="slice-np",
    )
    snl.run_inference(num_rounds=10, num_simulations_per_round=1000)

    # Convert everything to NumPy before handing it to the plotting helper.
    posterior_samples = utils.tensor2numpy(snl.sample_posterior(1000))
    true_parameters = utils.tensor2numpy(
        simulator.get_ground_truth_parameters()
    )
    figure = utils.plot_hist_marginals(
        data=posterior_samples,
        ground_truth=true_parameters.reshape(-1),
        lims=simulator.parameter_plotting_limits,
    )
    figure.savefig("./corner-posterior-snl.pdf")
Esempio n. 2
0
File: apt.py Progetto: yyht/lfi
def test_():
    """Run APT on the "nonlinear-gaussian" task and save two marginal plots.

    After 20 rounds of 1000 simulations each, 2500 samples are drawn with
    ``sample_posterior`` and 1000 with ``sample_posterior_mcmc``; each set is
    plotted against the simulator's ground-truth parameters and saved under
    ``utils.get_output_root()``.
    """
    simulator, prior = simulators.get_simulator_and_prior("nonlinear-gaussian")
    parameter_dim = simulator.parameter_dim
    observation_dim = simulator.observation_dim
    observed = simulator.get_ground_truth_observation()

    density_estimator = utils.get_neural_posterior(
        "maf", parameter_dim, observation_dim, simulator
    )
    apt = APT(
        simulator=simulator,
        true_observation=observed,
        prior=prior,
        neural_posterior=density_estimator,
        num_atoms=-1,
        use_combined_loss=False,
        train_with_mcmc=False,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        discard_prior_samples=False,
    )
    apt.run_inference(num_rounds=20, num_simulations_per_round=1000)

    def save_marginals(samples, filename):
        # Plot the samples against the ground truth and write the figure out.
        figure = utils.plot_hist_marginals(
            data=utils.tensor2numpy(samples),
            ground_truth=utils.tensor2numpy(
                simulator.get_ground_truth_parameters()
            ).reshape(-1),
            lims=simulator.parameter_plotting_limits,
        )
        figure.savefig(os.path.join(utils.get_output_root(), filename))

    save_marginals(apt.sample_posterior(2500), "corner-posterior-apt.pdf")
    save_marginals(
        apt.sample_posterior_mcmc(num_samples=1000),
        "corner-posterior-apt-mcmc.pdf",
    )
Esempio n. 3
0
File: slice.py Progetto: yyht/lfi
def test_():
    """Sample a potential function with Pyro's NUTS kernel and plot the result.

    The potential is built from ``LotkaVolterraOscillating`` and a 4-dimensional
    uniform prior on [-5, 2]. Three chains are run, the samples are plotted
    with ``utils.plot_hist_marginals`` and the figure is saved to a hard-coded
    absolute path.
    """
    # if torch.cuda.is_available():
    #     device = torch.device("cuda")
    #     torch.set_default_tensor_type("torch.cuda.FloatTensor")
    # else:
    #     input("CUDA not available, do you wish to continue?")
    #     device = torch.device("cpu")
    #     torch.set_default_tensor_type("torch.FloatTensor")

    # 2D Gaussian / bounded-uniform setup. `likelihood` and `prior` built here
    # are overwritten further down; only `loc` is used again (as the plot's
    # ground truth).
    loc = torch.Tensor([0, 0])
    covariance_matrix = torch.Tensor([[1, 0.99], [0.99, 1]])

    likelihood = distributions.MultivariateNormal(
        loc=loc, covariance_matrix=covariance_matrix)
    bound = 1.5
    low, high = -bound * torch.ones(2), bound * torch.ones(2)
    prior = distributions.Uniform(low=low, high=high)

    # def potential_function(inputs_dict):
    #     parameters = next(iter(inputs_dict.values()))
    #     return -(likelihood.log_prob(parameters) + prior.log_prob(parameters).sum())

    # Replaces the 2D prior and likelihood defined above.
    prior = distributions.Uniform(low=-5 * torch.ones(4),
                                  high=2 * torch.ones(4))
    from nsf import distributions as distributions_

    likelihood = distributions_.LotkaVolterraOscillating()
    potential_function = PotentialFunction(likelihood, prior)

    # kernel = Slice(potential_function=potential_function)
    # HMC is imported but only referenced by the commented-out line below.
    from pyro.infer.mcmc import HMC, NUTS

    # kernel = HMC(potential_fn=potential_function)
    kernel = NUTS(potential_fn=potential_function)
    num_chains = 3
    sampler = MCMC(
        kernel=kernel,
        # Split the 10000-sample budget across the chains.
        num_samples=10000 // num_chains,
        warmup_steps=200,
        # One 4-dimensional all-zero starting point per chain, under the key "".
        initial_params={"": torch.zeros(num_chains, 4)},
        num_chains=num_chains,
    )
    sampler.run()
    # Only one site is sampled, so take the first (only) entry of the dict.
    samples = next(iter(sampler.get_samples().values()))

    # NOTE(review): `loc` has 2 entries while each chain starts from a
    # 4-dimensional point — confirm the ground truth passed here is intended.
    utils.plot_hist_marginals(utils.tensor2numpy(samples),
                              ground_truth=utils.tensor2numpy(loc),
                              lims=[-6, 3])
    # plt.show()
    # NOTE(review): absolute, user-specific output path.
    plt.savefig("/home/conor/Dropbox/phd/projects/lfi/out/mcmc.pdf")
    plt.close()
Esempio n. 4
0
def gridimshow(image, ax):
    """Draw ``image`` on ``ax`` with the frame, ticks and tick labels hidden.

    When the leading dimension of ``image`` has size 1, its first slice is
    drawn inverted (``1 - image``) with the 'Greys' colormap. Otherwise the
    dimensions are reordered with ``permute(1, 2, 0)`` and drawn as-is.
    """
    is_single_slice = image.shape[0] == 1
    if is_single_slice:
        pixels = utils.tensor2numpy(image[0, ...])
        ax.imshow(1 - pixels, cmap='Greys')
    else:
        pixels = utils.tensor2numpy(image.permute(1, 2, 0))
        ax.imshow(pixels)

    # Hide the four borders of the axes.
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Zero-length ticks and empty labels leave only the image visible.
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')
Esempio n. 5
0
    def simulate(self, parameters):
        """Simulate inter-departure times for each row of ``parameters``.

        ``parameters`` is converted with ``utils.tensor2numpy`` and must have
        exactly three columns (p1, p2, p3): service times are drawn uniformly
        between p1 and p2, and inter-arrival times are exponential with
        rate p3. ``self.num_total_simulations`` is increased by the number of
        rows. The inter-departure times are passed through
        ``self._summarizer`` when ``self._summarize_observations`` is set, and
        returned as a ``torch.Tensor``.
        """
        parameters = utils.tensor2numpy(parameters)

        assert parameters.shape[1] == 3, "parameter must be 3-dimensional"
        batch_size = parameters.shape[0]
        num_steps = self.n_sim_steparameters
        # Column slices keep a trailing axis so they broadcast over the steps.
        p1 = parameters[:, 0:1]
        p2 = parameters[:, 1:2]
        p3 = parameters[:, 2:3]

        # Service times (uniformly distributed).
        service_times = (p2 - p1) * np.random.rand(batch_size, num_steps) + p1

        # Inter-arrival times (exponentially distributed), then arrival times.
        inter_arrival_times = (
            -np.log(1.0 - np.random.rand(batch_size, num_steps)) / p3
        )
        arrival_times = np.cumsum(inter_arrival_times, axis=1)

        inter_departure_times = np.empty((batch_size, num_steps), dtype=float)
        inter_departure_times[:, 0] = service_times[:, 0] + arrival_times[:, 0]

        # Only the previous departure time is needed at each step, so keep a
        # running vector instead of a full array of departure times.
        previous_departure = inter_departure_times[:, 0].copy()
        for step in range(1, num_steps):
            waiting = np.maximum(
                0.0, arrival_times[:, step] - previous_departure
            )
            inter_departure_times[:, step] = service_times[:, step] + waiting
            previous_departure = (
                previous_departure + inter_departure_times[:, step]
            )

        self.num_total_simulations += batch_size

        if self._summarize_observations:
            inter_departure_times = self._summarizer(inter_departure_times)

        return torch.Tensor(inter_departure_times)
Esempio n. 6
0
    def simulate(self, parameters):
        """
        Draws one 2D observation for every parameter in the batch.

        :param parameters: torch.Tensor or np.ndarray
            Batch of parameters; a 1D input is treated as a batch of one.
            Only the first two columns are read.
        :return: torch.Tensor
            Batch of observations with two columns.
        """

        # The simulator itself works on NumPy arrays.
        if isinstance(parameters, torch.Tensor):
            parameters = utils.tensor2numpy(parameters)

        # Promote a lone parameter vector to a batch of size one.
        if parameters.ndim == 1:
            return self.simulate(parameters[np.newaxis, ...])

        batch_size = parameters.shape[0]

        # Book-keeping of how many simulations have been run in total.
        self.num_total_simulations += batch_size

        # Random part: a point at radius ~N(0.1, 0.01) and angle uniform on
        # [-pi/2, pi/2), shifted by 0.25 along the first axis.
        angles = np.pi * np.random.rand(batch_size) - np.pi / 2
        radii = 0.01 * np.random.randn(batch_size) + 0.1
        noise = np.column_stack(
            [radii * np.cos(angles) + 0.25, radii * np.sin(angles)]
        )

        # Deterministic part: depends on the parameters only.
        first, second = parameters[:, 0], parameters[:, 1]
        shift = (1 / np.sqrt(2)) * np.column_stack([
            -np.abs(first + second),
            (-first + second),
        ])
        return torch.Tensor(noise + shift)
Esempio n. 7
0
File: sre.py Progetto: yyht/lfi
def test_():
    """Run SRE on the "lotka-volterra" task and save diagnostic plots.

    Trains an "mlp" classifier for 10 rounds of 1000 simulations each, plots
    the marginals of 2500 posterior samples against the ground-truth
    parameters, and, when the estimator's summary contains MMD values, plots
    them against the cumulative simulation count. Figures are written under
    ``utils.get_output_root()``.
    """
    simulator, prior = simulators.get_simulator_and_prior("lotka-volterra")
    parameter_dim = simulator.parameter_dim
    observation_dim = simulator.observation_dim
    observed = simulator.get_ground_truth_observation()

    classifier = utils.get_classifier("mlp", parameter_dim, observation_dim)
    ratio_estimator = SRE(
        simulator=simulator,
        true_observation=observed,
        classifier=classifier,
        prior=prior,
        num_atoms=-1,
        mcmc_method="slice-np",
        retrain_from_scratch_each_round=False,
    )

    num_rounds = 10
    num_simulations_per_round = 1000
    ratio_estimator.run_inference(
        num_rounds=num_rounds,
        num_simulations_per_round=num_simulations_per_round,
    )

    posterior_samples = utils.tensor2numpy(
        ratio_estimator.sample_posterior(num_samples=2500)
    )
    true_parameters = utils.tensor2numpy(
        simulator.get_ground_truth_parameters()
    ).reshape(-1)
    figure = utils.plot_hist_marginals(
        data=posterior_samples,
        ground_truth=true_parameters,
        lims=[-4, 4],
    )
    figure.savefig(
        os.path.join(utils.get_output_root(), "corner-posterior-ratio.pdf")
    )

    # The MMD curve is only drawn when some values were recorded.
    mmds = ratio_estimator.summary["mmds"]
    if not mmds:
        return

    figure, axes = plt.subplots(1, 1)
    simulation_counts = np.arange(
        0, num_rounds * num_simulations_per_round, num_simulations_per_round
    )
    axes.plot(simulation_counts, np.array(mmds), "-o", linewidth=2)
    figure.savefig(os.path.join(utils.get_output_root(), "mmd-ratio.pdf"))
Esempio n. 8
0
    def log_prob(self, observations, parameters):
        """
        Analytic log likelihood of a batch of observations.

        Each observation row is read as self._num_observations_per_parameter
        2D points, each whitened with the means, scales and correlation
        unpacked from the matching parameter row and scored under a standard
        normal.

        :param observations: torch.Tensor or np.ndarray [batch_size, observation_dim]
            Batch of observations.
        :param parameters: torch.Tensor or np.ndarray [batch_size, parameter_dim]
            Batch of parameters.
        :return: np.ndarray [batch_size]
            Log likelihood log p(x | theta) for each item in the batch.
        """

        # All of the computation below happens in NumPy.
        if isinstance(parameters, torch.Tensor):
            parameters = utils.tensor2numpy(parameters)
        if isinstance(observations, torch.Tensor):
            observations = utils.tensor2numpy(observations)

        # A single (observation, parameter) pair becomes a batch of one.
        if observations.ndim == 1 and parameters.ndim == 1:
            observations = observations.reshape(1, -1)
            parameters = parameters.reshape(1, -1)

        num_points = self._num_observations_per_parameter
        m0, m1, s0, s1, r = self._unpack_params(parameters)
        one_minus_r_sq = 1.0 - r**2
        logdet = np.log(s0) + np.log(s1) + 0.5 * np.log(one_minus_r_sq)

        # View every row as `num_points` 2D points.
        points = observations.reshape([observations.shape[0], num_points, 2])

        # Whiten: the second coordinate is conditioned on the first.
        whitened = np.empty_like(points)
        whitened[:, :, 0] = (points[:, :, 0] - m0) / s0
        whitened[:, :, 1] = (
            points[:, :, 1] - m1 - s1 * r * whitened[:, :, 0]
        ) / (s1 * np.sqrt(one_minus_r_sq))
        whitened = whitened.reshape([whitened.shape[0], 2 * num_points])

        standard_normal_term = np.sum(scipy.stats.norm.logpdf(whitened), axis=1)
        return standard_normal_term - num_points * logdet[:, 0]
Esempio n. 9
0
    def test_change(self):
        """Load the latest checkpoint and write translated test images to disk.

        The entries of ``<result_dir>/<dataset>/model`` are expected to be
        named by their iteration number; the numerically largest one is
        loaded with ``self.load``. If there is no such entry, a failure
        message is printed and nothing else happens.

        Every item yielded by ``self.testA_loader()`` is then passed through
        ``self.genA2B`` and saved as ``<stem>_fake.<ext>`` under
        ``<result_dir>/<dataset>/test/testA2B``; items from
        ``self.testB_loader()`` go through ``self.genB2A`` into ``testB2A``.
        """
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')

        # Compare iteration numbers numerically: a plain string sort would
        # rank e.g. '900' after '1000' and load a stale checkpoint.
        iterations = [
            int(name) for name in os.listdir(model_dir) if name.isdigit()
        ]
        if not iterations:
            print("[*] Load FAILURE")
            return
        self.load(model_dir, max(iterations))
        print("[*] Load SUCCESS")

        self.genA2B.eval()
        self.genB2A.eval()

        # (loader factory, generator, output sub-directory) per direction.
        directions = (
            (self.testA_loader, self.genA2B, 'testA2B'),
            (self.testB_loader, self.genB2A, 'testB2A'),
        )
        for make_loader, generator, out_subdir in directions:
            for real, fname in make_loader():
                # Wrap the first sample as a float32 batch of one 3x256x256
                # image.
                real = np.array([real[0].reshape(3, 256,
                                                 256)]).astype("float32")
                fake, _, _ = generator(to_variable(real))

                image = RGB2BGR(tensor2numpy(denorm(fake[0])))

                # Output keeps the input's stem and extension.
                out_name = '%s_fake.%s' % (fname.split('.')[0],
                                           fname.split('.')[-1])
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'test',
                                 out_subdir, out_name), image * 255.0)
Esempio n. 10
0
File: uniform.py Progetto: yyht/lfi
def _test():
    """Show marginal histograms of 1000 LotkaVolterraOscillating samples."""
    samples = LotkaVolterraOscillating().sample((1000,))
    samples = utils.tensor2numpy(samples)
    utils.plot_hist_marginals(samples, lims=[-6, 3])
    plt.show()
Esempio n. 11
0
    def simulate(self, parameters):
        """Run the jump process once for every row of ``parameters``.

        ``parameters`` is converted with ``utils.tensor2numpy`` and
        exponentiated before use. Each run contributes the flattened states
        as a ``torch.Tensor``, or ``None`` when the run raises
        ``SimTooLongException``. ``self.num_total_simulations`` is increased
        by one per run.

        Returns the list of per-run results, passed through
        ``self._summarizer`` when ``self._summarize_observations`` is set.
        """
        rates = np.exp(utils.tensor2numpy(parameters))

        observations = []
        for rate in rates:
            try:
                self._jump_process.reset(self._initial_populations, rate)
                states = self._jump_process.simulate_for_time(
                    self._dt, self._duration, max_n_steps=self._max_num_steps)
                outcome = torch.Tensor(states.flatten())
            except SimTooLongException:
                # Keep a placeholder so results stay aligned with the inputs.
                outcome = None
            observations.append(outcome)
            self.num_total_simulations += 1

        if not self._summarize_observations:
            return observations
        return self._summarizer(observations)
Esempio n. 12
0
    def simulate(self, parameters):
        """
        Draws normalized observations for every parameter in the batch.

        :param parameters: torch.Tensor or np.ndarray
            Batch of parameters; a 1D input is treated as a batch of one and
            the single resulting row is returned.
        :return: torch.Tensor
            Batch of observations, shifted and scaled by the values from
            self._get_observation_normalization_parameters().
        """

        # The simulator itself works on NumPy arrays.
        if isinstance(parameters, torch.Tensor):
            parameters = utils.tensor2numpy(parameters)

        # Promote a lone parameter vector to a batch, then unwrap the result.
        if parameters.ndim == 1:
            batch = self.simulate(parameters[np.newaxis, :])
            return batch[0]

        batch_size = parameters.shape[0]

        # Book-keeping of how many simulations have been run in total.
        self.num_total_simulations += batch_size

        # Each parameter row describes a 2D Gaussian (two means, two scales
        # and a correlation) from which several points are drawn.
        m0, m1, s0, s1, r = self._unpack_params(parameters)
        num_points = self._num_observations_per_parameter

        noise = np.random.randn(batch_size, num_points, 2)
        z0, z1 = noise[:, :, 0], noise[:, :, 1]

        observations = np.empty_like(noise)
        observations[:, :, 0] = s0 * z0 + m0
        observations[:, :, 1] = (
            s1 * (r * z0 + np.sqrt(1.0 - r**2) * z1) + m1)

        mean, std = self._get_observation_normalization_parameters()
        flat = torch.Tensor(
            observations.reshape([batch_size, 2 * num_points]))
        return (flat - mean.reshape(1, -1)) / std.reshape(1, -1)
Esempio n. 13
0
def viz_to_tb(dataloader, writer, num_classes, display_num=4):
    """Log a grid of example images for each label via ``writer.add_image``.

    Batches are read from ``dataloader`` for as long as every label in
    ``range(num_classes)`` has been seen fewer than ``display_num`` times, or
    until the loader runs out of batches. At most ``display_num`` images are
    kept per label. Each label's images are tiled with ``make_grid``,
    converted with ``tensor2numpy`` and logged in 'HWC' layout.

    :param dataloader: iterable yielding ``(inputs, labels)`` batches.
    :param writer: object exposing ``add_image(tag, img, dataformats=...)``
        (presumably a TensorBoard summary writer — confirm against callers).
    :param num_classes: labels are expected to lie in ``range(num_classes)``;
        any other label raises ``KeyError``.
    :param display_num: per-label cap on the number of collected images.
    """
    from collections import defaultdict
    from torchvision.utils import make_grid
    from utils import tensor2numpy

    labels_count = {k: 0 for k in range(num_classes)}
    imgs_dict = defaultdict(list)

    dl_iter = iter(dataloader)
    # NOTE(review): collection stops as soon as ANY label reaches
    # `display_num`, so other labels may end up with fewer images. Confirm
    # whether `any(...)` was intended here.
    while all(v < display_num for v in labels_count.values()):
        try:
            inputs, labels = next(dl_iter)
        except StopIteration:
            # The loader ran out before any label filled up; log what we have
            # instead of letting StopIteration escape to the caller.
            break
        for input_, label in zip(inputs, labels):
            label = int(label)
            if labels_count[label] < display_num:
                imgs_dict[label].append(input_)
            labels_count[label] += 1

    for label, imgs in imgs_dict.items():
        img_grid = tensor2numpy(make_grid(imgs))
        writer.add_image(f'example image for label {label}',
                         img_grid,
                         dataformats='HWC')
Esempio n. 14
0
File: mg1.py Progetto: yyht/lfi
 def simulate(self, parameters):
     """Simulate from ``parameters`` and return the summarized result.

     The parameters are converted with ``utils.tensor2numpy``, run through
     ``self._simulator.sim``, reduced with ``self._summarizer.calc`` and
     wrapped in a ``torch.Tensor``.
     """
     raw_output = self._simulator.sim(utils.tensor2numpy(parameters))
     summary = self._summarizer.calc(raw_output)
     return torch.Tensor(summary)
Esempio n. 15
0
File: sre.py Progetto: yyht/lfi
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        classifier,
        num_atoms=-1,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        summary_writer=None,
    ):
        """
        :param simulator: Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter as a torch.Tensor.
        :param prior: Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor containing the observation x0 for which to
        perform inference on the posterior p(theta | x0).
        :param classifier: Binary classifier in the form of an nets.Module.
        Takes as input (x, theta) pairs and outputs pre-sigmoid activations.
        :param num_atoms: int
            Number of atoms to use for classification.
            If -1, use all other parameters in minibatch.
        :param mcmc_method: Name of the MCMC method; stored unchanged on the instance.
        :param summary_net: Optional network which may be used to produce feature vectors
        f(x) for high-dimensional observations. Defaults to an identity module.
        :param retrain_from_scratch_each_round: Whether to retrain the classifier
        from scratch each round; if so, an untrained copy of it is kept.
        :param summary_writer: Optional writer for logged output. If None, a
        SummaryWriter is created in a timestamped directory under the log root.
        """

        self._simulator = simulator
        self._true_observation = true_observation
        self._classifier = classifier
        self._prior = prior

        assert isinstance(num_atoms,
                          int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms
        self._mcmc_method = mcmc_method

        # Observations may be summarized by a fixed or learned transformation;
        # without one, they pass through unchanged.
        self._summary_net = (
            nn.Identity() if summary_net is None else summary_net)

        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            classifier, prior, true_observation)

        # TODO: decide on Slice Sampling implementation
        def target_log_prob(parameters):
            # Classifier output at (parameters, true observation) plus the
            # summed prior log-probability of the parameters.
            classifier_input = torch.cat(
                (torch.Tensor(parameters), self._true_observation)
            ).reshape(1, -1)
            classifier_term = self._classifier(classifier_input).item()
            prior_term = self._prior.log_prob(
                torch.Tensor(parameters)).sum().item()
            return classifier_term + prior_term

        # The sampler is built with the classifier in eval mode and starts
        # from a single draw from the prior.
        self._classifier.eval()
        initial_position = utils.tensor2numpy(
            self._prior.sample((1, ))).reshape(-1)
        self.posterior_sampler = SliceSampler(
            initial_position,
            lp_f=target_log_prob,
            thin=10,
        )
        self._classifier.train()

        # When retraining from scratch each round, an untrained copy of the
        # classifier is kept for reinitialization.
        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        self._untrained_classifier = (
            deepcopy(classifier) if retrain_from_scratch_each_round else None)

        # Storage for the (parameter, observation) pairs of each round.
        self._parameter_bank = []
        self._observation_bank = []

        # Each SRE run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(utils.get_log_root(), "sre", simulator.name,
                                   utils.get_timestamp())
            summary_writer = SummaryWriter(log_dir)
        self._summary_writer = summary_writer

        # Summary statistics populated over the course of training; every
        # entry starts as its own empty list.
        summary_keys = (
            "mmds",
            "median-observation-distances",
            "negative-log-probs-true-parameters",
            "neural-net-fit-times",
            "mcmc-times",
            "epochs",
            "best-validation-log-probs",
        )
        self._summary = {key: [] for key in summary_keys}
Esempio n. 16
0
    def train(self):
        self.genA2B.train(), self.genB2A.train()
        self.disGA.train(), self.disGB.train()
        self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
            if not len(model_list) == 0:
                model_list.sort()
                start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
                self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
                print(" [*] Load SUCCESS")
                if self.decay_flag and start_iter > (self.iteration // 2):
                    self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) \
                        * (start_iter - self.iteration // 2)
                    self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) \
                        * (start_iter - self.iteration // 2)

        # training loop
        print('training start !')
        start_time = time.time()

        for step in range(start_iter, self.iteration + 1):
            if self.decay_flag and step > (self.iteration // 2):
                self.G_optim.param_groups[0]['lr'] -= (
                    self.lr / (self.iteration // 2))
                self.D_optim.param_groups[0]['lr'] -= (
                    self.lr / (self.iteration // 2))

            try:
                real_A, _ = trainA_iter.next()  # noqa: F821
            except Exception:
                trainA_iter = iter(self.trainA_loader)
                real_A, _ = trainA_iter.next()

            try:
                real_B, _ = trainB_iter.next()  # noqa: F821
            except Exception:
                trainB_iter = iter(self.trainB_loader)
                real_B, _ = trainB_iter.next()

            real_A, real_B = real_A.to(self.device), real_B.to(self.device)

            # Update D
            self.D_optim.zero_grad()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(
                self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
            D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(
                self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
            D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(
                self.device)) + self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))
            D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(
                self.device)) + self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))
            D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(
                self.device)) + self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))
            D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(
                self.device)) + self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))
            D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(
                self.device)) + self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))
            D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(
                self.device)) + self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))

            D_loss_A = self.adv_weight * \
                (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * \
                (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.step()

            # Update G
            self.G_optim.zero_grad()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(
                fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
            G_ad_cam_loss_GA = self.MSE_loss(
                fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
            G_ad_loss_LA = self.MSE_loss(
                fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
            G_ad_cam_loss_LA = self.MSE_loss(
                fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))
            G_ad_loss_GB = self.MSE_loss(
                fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
            G_ad_cam_loss_GB = self.MSE_loss(
                fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
            G_ad_loss_LB = self.MSE_loss(
                fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
            G_ad_cam_loss_LB = self.MSE_loss(
                fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(
                self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
            G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(
                self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))

            G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + \
                self.cycle_weight * G_recon_loss_A + self.identity_weight * \
                G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + \
                self.cycle_weight * G_recon_loss_B + self.identity_weight * \
                G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.step()

            # clip parameter of AdaILN and ILN, applied after optimizer step
            self.genA2B.apply(self.Rho_clipper)
            self.genB2A.apply(self.Rho_clipper)
            msg = "[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iteration, time.time() - start_time,
                                                                        Discriminator_loss, Generator_loss)
            print(msg)
            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval()
                self.disGA.eval(), self.disGB.eval()
                self.disLA.eval(), self.disLB.eval()

                for _ in range(train_sample_num):
                    try:
                        real_A, _ = trainA_iter.next()
                    except Exception:
                        trainA_iter = iter(self.trainA_loader)
                        real_A, _ = trainA_iter.next()

                    try:
                        real_B, _ = trainB_iter.next()
                    except Exception:
                        trainB_iter = iter(self.trainB_loader)
                        real_B, _ = trainB_iter.next()

                    real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                                               cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                                               cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                                               cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                    B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                                               cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                                               cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                                               cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                for _ in range(test_sample_num):
                    try:
                        real_A, _ = testA_iter.next()  # noqa: F821
                    except Exception:
                        testA_iter = iter(self.testA_loader)
                        real_A, _ = testA_iter.next()

                    try:
                        real_B, _ = testB_iter.next()  # noqa: F821
                    except Exception:
                        testB_iter = iter(self.testB_loader)
                        real_B, _ = testB_iter.next()
                    real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                                               cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                                               cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                                               cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                    B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                                               cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                                               cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                                               cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                                                               RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0)

                self.genA2B.train(), self.genB2A.train()
                self.disGA.train(), self.disGB.train()
                self.disLA.train(), self.disLB.train()

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)

            if step % 1000 == 0:
                params = {}
                params['genA2B'] = self.genA2B.state_dict()
                params['genB2A'] = self.genB2A.state_dict()
                params['disGA'] = self.disGA.state_dict()
                params['disGB'] = self.disGB.state_dict()
                params['disLA'] = self.disLA.state_dict()
                params['disLB'] = self.disLB.state_dict()
                torch.save(params, os.path.join(self.result_dir, self.dataset + '_params_latest.pt'))
Esempio n. 17
0
    def test(self):
        """Translate the test sets with the most recently saved checkpoint.

        Looks for ``*.pt`` files in ``<result_dir>/<dataset>/model``, sorts the
        paths and loads the last one via ``self.load``. It then runs both
        generators over ``self.testA_loader`` and ``self.testB_loader`` and
        writes one PNG per batch to ``<result_dir>/<dataset>/test``.

        Each PNG is the concatenation along axis 0 of: the input image, the
        identity-mapping heatmap and output, the translation heatmap and
        output, and the cycle-reconstruction heatmap and output. Only the
        first image of every batch is written.

        Prints " [*] Load FAILURE" and returns early when no checkpoint exists.
        """
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')
        model_list = glob(os.path.join(model_dir, '*.pt'))
        if not model_list:
            print(" [*] Load FAILURE")
            return

        # The step is parsed from the "..._<step>.pt" suffix of the path that
        # sorts last. NOTE(review): string sorting only matches step order if
        # the step is zero-padded -- confirm against the save routine.
        model_list.sort()
        # Deliberately not called `iter`, which would shadow the builtin.
        latest_step = int(model_list[-1].split('_')[-1].split('.')[0])
        self.load(model_dir, latest_step)
        print(" [*] Load SUCCESS")

        test_dir = os.path.join(self.result_dir, self.dataset, 'test')

        self.genA2B.eval(), self.genB2A.eval()

        # Domain A -> B: translation, cycle reconstruction and identity mapping.
        for index, (real_A, _) in enumerate(self.testA_loader, start=1):
            real_A = real_A.to(self.device)

            fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
            fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
            fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)

            A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                  cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                  cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                  cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)

            cv2.imwrite(os.path.join(test_dir, 'A2B_%d.png' % index), A2B * 255.0)

        # Domain B -> A: the mirror image of the loop above.
        for index, (real_B, _) in enumerate(self.testB_loader, start=1):
            real_B = real_B.to(self.device)

            fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
            fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
            fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

            B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                  cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                  cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                  cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                                  RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)

            cv2.imwrite(os.path.join(test_dir, 'B2A_%d.png' % index), B2A * 255.0)
Esempio n. 18
0
File: apt.py Progetto: yyht/lfi
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        neural_posterior,
        num_atoms=-1,
        use_combined_loss=False,
        train_with_mcmc=False,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        discard_prior_samples=False,
        summary_writer=None,
    ):
        """
        Store the inference configuration and set up samplers, banks and logging.

        :param simulator:
            Object exposing a 'simulate' method that maps a torch.Tensor of
            parameters to a torch.Tensor of simulation results, one per parameter.
        :param prior: Distribution
            Distribution with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor [observation_dim] or [1, observation_dim]
            Observation x0 conditioning the target posterior p(theta | x0).
        :param neural_posterior: nets.Module
            Conditional density estimator q(theta | x) with 'log_prob' and
            'sample' methods.
        :param num_atoms: int
            Number of atoms used for classification; -1 means all other
            parameters in the minibatch.
        :param use_combined_loss: bool
            Whether prior samples are additionally trained with maximum
            likelihood (helps against density leaking with box uniform priors).
        :param train_with_mcmc: bool
            Whether to draw samples with MCMC rather than i.i.d. at the end of
            each round.
        :param mcmc_method: str
            MCMC method used when 'train_with_mcmc' is True; defaults to
            'slice-np'. NOTE(review): earlier documentation listed the options
            as ['slice-numpy', 'hmc', 'nuts'] -- confirm the accepted spellings
            where this value is consumed.
        :param summary_net: nets.Module
            Optional network producing feature vectors f(x) for
            high-dimensional observations; nn.Identity() is used when None.
        :param retrain_from_scratch_each_round: bool
            Whether the density estimator is re-initialized every round.
        :param discard_prior_samples: bool
            Whether prior samples are dropped from round two onwards.
        :param summary_writer: SummaryWriter
            Optional writer; when None one is created under
            <log root>/apt/<simulator.name>/<timestamp>.
        """

        self._simulator = simulator
        self._prior = prior
        self._true_observation = true_observation
        self._neural_posterior = neural_posterior

        assert isinstance(num_atoms, int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms

        self._use_combined_loss = use_combined_loss

        # Observations may be summarized by a fixed or learned transformation;
        # without one they pass through unchanged.
        self._summary_net = nn.Identity() if summary_net is None else summary_net

        self._mcmc_method = mcmc_method
        self._train_with_mcmc = train_with_mcmc

        # HMC and NUTS come from Pyro. Wrapping the potential in an object lets
        # Pyro's MCMC pickle it for parallel chains even though evaluating it
        # requires a neural network.
        self._potential_function = NeuralPotentialFunction(
            neural_posterior, prior, self._true_observation
        )

        # Axis-aligned slice sampling implementation in NumPy.
        def target_log_prob(parameters):
            # Outside the prior support the target is -inf and the density
            # estimator is not evaluated at all.
            prior_log_prob = (
                self._prior.log_prob(torch.Tensor(parameters)).sum().item()
            )
            if np.isinf(prior_log_prob):
                return -np.inf
            return self._neural_posterior.log_prob(
                inputs=torch.Tensor(parameters).reshape(1, -1),
                context=self._true_observation.reshape(1, -1),
            ).item()

        # The sampler is constructed while the estimator is in eval mode;
        # training mode is restored right afterwards.
        self._neural_posterior.eval()
        initial_parameters = utils.tensor2numpy(self._prior.sample((1,))).reshape(-1)
        self.posterior_sampler = SliceSampler(
            initial_parameters, lp_f=target_log_prob, thin=10
        )
        self._neural_posterior.train()

        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        # An untouched copy of the initial model is always kept so it can be
        # used for reinitialization when retraining from scratch.
        self._untrained_neural_posterior = deepcopy(neural_posterior)

        self._discard_prior_samples = discard_prior_samples

        # Per-round storage for (parameter, observation) pairs and prior masks.
        self._parameter_bank = []
        self._observation_bank = []
        self._prior_masks = []

        self._model_bank = []

        self._total_num_generated_examples = 0

        # Each APT run gets its own TensorBoard log directory unless the
        # caller supplies a writer.
        if summary_writer is None:
            summary_writer = SummaryWriter(
                os.path.join(
                    utils.get_log_root(), "apt", simulator.name, utils.get_timestamp()
                )
            )
        self._summary_writer = summary_writer

        # Summary statistics populated over the course of training; every
        # entry starts as its own empty list.
        summary_keys = (
            "mmds",
            "median-observation-distances",
            "negative-log-probs-true-parameters",
            "neural-net-fit-times",
            "epochs",
            "best-validation-log-probs",
            "rejection-sampling-acceptance-rates",
        )
        self._summary = {key: [] for key in summary_keys}
Esempio n. 19
0
def run(seed):
    """Train a VAE on a stochastically binarized image dataset and evaluate it.

    All hyperparameters are read from the module-level ``args`` namespace.
    The function:

    1. seeds NumPy and PyTorch with ``seed`` and builds train / validation /
       test loaders for ``args.dataset_name``;
    2. assembles prior, approximate posterior, likelihood and optional inputs
       encoder into a ``vae.VariationalAutoencoder``;
    3. trains for ``args.num_training_steps`` steps with Adam and cosine
       annealing, and every ``args.monitor_interval`` steps evaluates the ELBO
       on one fixed validation batch, checkpoints the best model and logs
       scalars plus a grid of samples to TensorBoard;
    4. reloads the best checkpoint, re-seeds with 5, computes the ELBO and a
       1000-sample log-probability lower bound on the test set, saves both as
       ``.npy`` files and writes a one-line summary to ``test-results.txt``.

    :param seed: seed for ``np.random`` and ``torch`` during data splitting
        and training.
    :raises AssertionError: if CUDA is not available.
    :raises ValueError: if the dataset name, linear type, prior type or KL
        multiplier schedule in ``args`` is not recognised.
    """
    assert torch.cuda.is_available()
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Create training data. Pixels are sampled as Bernoulli variables from the
    # ToTensor output every time an item is fetched.
    data_transform = tvtransforms.Compose(
        [tvtransforms.ToTensor(),
         tvtransforms.Lambda(torch.bernoulli)])

    if args.dataset_name == 'mnist':
        dataset = datasets.MNIST(root=os.path.join(utils.get_data_root(),
                                                   'mnist'),
                                 train=True,
                                 download=True,
                                 transform=data_transform)
        test_dataset = datasets.MNIST(root=os.path.join(
            utils.get_data_root(), 'mnist'),
                                      train=False,
                                      download=True,
                                      transform=data_transform)
    elif args.dataset_name == 'fashion-mnist':
        dataset = datasets.FashionMNIST(root=os.path.join(
            utils.get_data_root(), 'fashion-mnist'),
                                        train=True,
                                        download=True,
                                        transform=data_transform)
        test_dataset = datasets.FashionMNIST(root=os.path.join(
            utils.get_data_root(), 'fashion-mnist'),
                                             train=False,
                                             download=True,
                                             transform=data_transform)
    elif args.dataset_name == 'omniglot':
        dataset = data_.OmniglotDataset(split='train',
                                        transform=data_transform)
        test_dataset = data_.OmniglotDataset(split='test',
                                             transform=data_transform)
    elif args.dataset_name == 'emnist':
        # EMNIST images are rotated and flipped before binarization.
        rotate = partial(tvF.rotate, angle=-90)
        hflip = tvF.hflip
        data_transform = tvtransforms.Compose([
            tvtransforms.Lambda(rotate),
            tvtransforms.Lambda(hflip),
            tvtransforms.ToTensor(),
            tvtransforms.Lambda(torch.bernoulli)
        ])
        dataset = datasets.EMNIST(root=os.path.join(utils.get_data_root(),
                                                    'emnist'),
                                  split='letters',
                                  train=True,
                                  transform=data_transform,
                                  download=True)
        test_dataset = datasets.EMNIST(root=os.path.join(
            utils.get_data_root(), 'emnist'),
                                       split='letters',
                                       train=False,
                                       transform=data_transform,
                                       download=True)
    else:
        raise ValueError('Unknown dataset name: {}'.format(args.dataset_name))

    # The last |split| shuffled indices are held out for validation.
    if args.dataset_name == 'omniglot':
        split = -1345
    elif args.dataset_name == 'emnist':
        split = -20000
    else:
        split = -10000
    indices = np.arange(len(dataset))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[:split], indices[split:]
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    train_loader = data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4 if args.dataset_name == 'emnist' else 0)
    train_generator = data_.batch_generator(train_loader)
    val_loader = data.DataLoader(dataset=dataset,
                                 batch_size=1024,
                                 sampler=val_sampler,
                                 shuffle=False,
                                 drop_last=False)
    # Monitoring uses this single batch of validation inputs throughout.
    val_batch = next(iter(val_loader))[0]
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
    )

    def create_linear_transform():
        """Return the linear transform selected by ``args.linear_type``."""
        if args.linear_type == 'lu':
            return transforms.CompositeTransform([
                transforms.RandomPermutation(args.latent_features),
                transforms.LULinear(args.latent_features, identity_init=True)
            ])
        elif args.linear_type == 'svd':
            return transforms.SVDLinear(args.latent_features,
                                        num_householder=4,
                                        identity_init=True)
        elif args.linear_type == 'perm':
            return transforms.RandomPermutation(args.latent_features)
        else:
            raise ValueError('Unknown linear type: {}'.format(
                args.linear_type))

    def create_base_transform(i, context_features=None):
        """Return flow step ``i`` of the kind selected by ``args.prior_type``.

        For the coupling transforms the parity of ``i`` decides which
        alternating mask is used.
        """
        if args.prior_type == 'affine-coupling':
            return transforms.AffineCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=lambda in_features, out_features: nn_.
                ResidualNet(in_features=in_features,
                            out_features=out_features,
                            hidden_features=args.hidden_features,
                            context_features=context_features,
                            num_blocks=args.num_transform_blocks,
                            activation=F.relu,
                            dropout_probability=args.dropout_probability,
                            use_batch_norm=args.use_batch_norm))
        elif args.prior_type == 'rq-coupling':
            return transforms.PiecewiseRationalQuadraticCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=lambda in_features, out_features: nn_.
                ResidualNet(in_features=in_features,
                            out_features=out_features,
                            hidden_features=args.hidden_features,
                            context_features=context_features,
                            num_blocks=args.num_transform_blocks,
                            activation=F.relu,
                            dropout_probability=args.dropout_probability,
                            use_batch_norm=args.use_batch_norm),
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                apply_unconditional_transform=args.
                apply_unconditional_transform,
            )
        elif args.prior_type == 'affine-autoregressive':
            return transforms.MaskedAffineAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        elif args.prior_type == 'rq-autoregressive':
            return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        else:
            raise ValueError('Unknown prior type: {}'.format(args.prior_type))

    # ---------------
    # prior
    # ---------------
    def create_prior():
        """Return a standard normal prior, or a flow on top of one."""
        if args.prior_type == 'standard-normal':
            prior = distributions_.StandardNormal((args.latent_features, ))

        else:
            distribution = distributions_.StandardNormal(
                (args.latent_features, ))
            transform = transforms.CompositeTransform([
                transforms.CompositeTransform(
                    [create_linear_transform(),
                     create_base_transform(i)])
                for i in range(args.num_flow_steps)
            ])
            # One extra linear transform closes the stack of flow steps.
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            prior = flows.Flow(transform, distribution)

        return prior

    # ---------------
    # inputs encoder
    # ---------------
    def create_inputs_encoder():
        """Return a ConvEncoder, or None for a diagonal-normal posterior."""
        if args.approximate_posterior_type == 'diagonal-normal':
            inputs_encoder = None
        else:
            inputs_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
        return inputs_encoder

    # ---------------
    # approximate posterior
    # ---------------
    def create_approximate_posterior():
        """Return a conditional diagonal normal, or a conditional flow on one."""
        if args.approximate_posterior_type == 'diagonal-normal':
            context_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            approximate_posterior = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)

        else:
            context_encoder = nn.Linear(args.context_features,
                                        2 * args.latent_features)
            distribution = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)

            transform = transforms.CompositeTransform([
                transforms.CompositeTransform([
                    create_linear_transform(),
                    create_base_transform(
                        i, context_features=args.context_features)
                ]) for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            # The posterior flow uses the inverse of the composed transform.
            approximate_posterior = flows.Flow(
                transforms.InverseTransform(transform), distribution)

        return approximate_posterior

    # ---------------
    # likelihood
    # ---------------
    def create_likelihood():
        """Return an independent Bernoulli likelihood over [1, 28, 28] outputs."""
        latent_decoder = nn_.ConvDecoder(
            latent_features=args.latent_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)

        likelihood = distributions_.ConditionalIndependentBernoulli(
            shape=[1, 28, 28], context_encoder=latent_decoder)

        return likelihood

    prior = create_prior()
    approximate_posterior = create_approximate_posterior()
    likelihood = create_likelihood()
    inputs_encoder = create_inputs_encoder()

    model = vae.VariationalAutoencoder(
        prior=prior,
        approximate_posterior=approximate_posterior,
        likelihood=likelihood,
        inputs_encoder=inputs_encoder)

    n_params = utils.get_num_parameters(model)
    print('There are {} trainable parameters in this model.'.format(n_params))

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=args.num_training_steps, eta_min=0)

    def get_kl_multiplier(step):
        """Return the KL weight for ``step`` under the configured schedule.

        'constant' always yields ``args.kl_multiplier_initial``. 'linear'
        grows from that value to twice that value over the first
        ``args.kl_warmup_fraction`` of training and stays there afterwards.
        """
        if args.kl_multiplier_schedule == 'constant':
            return args.kl_multiplier_initial
        elif args.kl_multiplier_schedule == 'linear':
            multiplier = min(
                step / (args.num_training_steps * args.kl_warmup_fraction), 1.)
            return args.kl_multiplier_initial * (1. + multiplier)
        else:
            # Fail loudly instead of handing None to stochastic_elbo.
            raise ValueError('Unknown KL multiplier schedule: {}'.format(
                args.kl_multiplier_schedule))

    # create summary writer and write to log directory
    timestamp = cutils.get_timestamp()
    if cutils.on_cluster():
        timestamp += '||{}'.format(os.environ['SLURM_JOB_ID'])
    log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
    # Retry until the writer can be created; FileExistsError is the only
    # error that is retried.
    while True:
        try:
            writer = SummaryWriter(log_dir=log_dir, max_queue=20)
            break
        except FileExistsError:
            sleep(5)
    filename = os.path.join(log_dir, 'config.json')
    with open(filename, 'w') as file:
        json.dump(vars(args), file)

    best_val_elbo = -np.inf
    tbar = tqdm(range(args.num_training_steps))
    for step in tbar:
        model.train()
        optimizer.zero_grad()
        scheduler.step(step)

        batch = next(train_generator)[0].to(device)
        elbo = model.stochastic_elbo(batch,
                                     kl_multiplier=get_kl_multiplier(step))
        loss = -torch.mean(elbo)
        loss.backward()
        optimizer.step()

        if (step + 1) % args.monitor_interval == 0:
            model.eval()
            with torch.no_grad():
                elbo = model.stochastic_elbo(val_batch.to(device))
                mean_val_elbo = elbo.mean()

            # The checkpoint is only (re)written when validation improves.
            if mean_val_elbo > best_val_elbo:
                best_val_elbo = mean_val_elbo
                path = os.path.join(
                    cutils.get_checkpoint_root(),
                    '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
                torch.save(model.state_dict(), path)

            writer.add_scalar(tag='val-elbo',
                              scalar_value=mean_val_elbo,
                              global_step=step)

            writer.add_scalar(tag='best-val-elbo',
                              scalar_value=best_val_elbo,
                              global_step=step)

            with torch.no_grad():
                samples = model.sample(64)
            fig, ax = plt.subplots(figsize=(10, 10))
            cutils.gridimshow(make_grid(samples.view(64, 1, 28, 28), nrow=8),
                              ax)
            writer.add_figure(tag='vae-samples', figure=fig, global_step=step)
            plt.close()

    # load best val model
    # NOTE(review): this file only exists if at least one monitoring step
    # improved on -inf; otherwise torch.load will fail here.
    path = os.path.join(
        cutils.get_checkpoint_root(),
        '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
    model.load_state_dict(torch.load(path))
    model.eval()

    # Evaluation uses a fixed seed regardless of the training seed.
    np.random.seed(5)
    torch.manual_seed(5)

    # compute elbo on test set
    with torch.no_grad():
        elbo = torch.Tensor([])
        log_prob_lower_bound = torch.Tensor([])
        for batch in tqdm(test_loader):
            elbo_ = model.stochastic_elbo(batch[0].to(device))
            elbo = torch.cat([elbo, elbo_])
            log_prob_lower_bound_ = model.log_prob_lower_bound(
                batch[0].to(device), num_samples=1000)
            log_prob_lower_bound = torch.cat(
                [log_prob_lower_bound, log_prob_lower_bound_])
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-elbo.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(elbo))
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-log-prob-lower-bound.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(log_prob_lower_bound))

    # save elbo and log prob lower bound
    # Each mean is reported with 2 * std / sqrt(number of test examples).
    mean_elbo = elbo.mean()
    std_elbo = elbo.std()
    mean_log_prob_lower_bound = log_prob_lower_bound.mean()
    std_log_prob_lower_bound = log_prob_lower_bound.std()
    s = 'ELBO: {:.2f} +- {:.2f}, LOG PROB LOWER BOUND: {:.2f} +- {:.2f}'.format(
        mean_elbo.item(), 2 * std_elbo.item() / np.sqrt(len(test_dataset)),
        mean_log_prob_lower_bound.item(),
        2 * std_log_prob_lower_bound.item() / np.sqrt(len(test_dataset)))
    filename = os.path.join(log_dir, 'test-results.txt')
    with open(filename, 'w') as file:
        file.write(s)
Esempio n. 20
0
File: snl.py Progetto: yyht/lfi
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        neural_likelihood,
        mcmc_method="slice-np",
        summary_writer=None,
    ):
        """
        Store the inference configuration and set up the sampler, banks and logging.

        :param simulator: Python object with a 'simulate' method mapping a torch.Tensor
        of parameter values to a torch.Tensor with one simulation result per parameter.
        :param prior: Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor holding the observation x0 that conditions
        the target posterior p(theta | x0).
        :param neural_likelihood: Conditional density estimator q(x | theta) given as an
        nets.Module with 'log_prob' and 'sample' methods.
        :param mcmc_method: MCMC method used for posterior sampling; defaults to
        'slice-np'. NOTE(review): earlier documentation listed the options as
        ['slice', 'hmc', 'nuts'] -- confirm the accepted spellings where this value
        is consumed.
        :param summary_writer: Optional SummaryWriter; when None one is created under
        <log root>/snl/<simulator.name>/<timestamp>.
        """

        self._simulator = simulator
        self._prior = prior
        self._true_observation = true_observation
        self._neural_likelihood = neural_likelihood
        self._mcmc_method = mcmc_method

        # Wrapping the potential in an object lets Pyro's MCMC pickle it for
        # parallel chains even though evaluating it requires a neural
        # likelihood.
        self._potential_function = NeuralPotentialFunction(
            neural_likelihood=self._neural_likelihood,
            prior=self._prior,
            true_observation=self._true_observation,
        )

        # TODO: decide on Slice Sampling implementation
        def target_log_prob(parameters):
            # Log-likelihood of the observation plus the summed prior log-prob.
            log_likelihood = self._neural_likelihood.log_prob(
                inputs=self._true_observation.reshape(1, -1),
                context=torch.Tensor(parameters).reshape(1, -1),
            ).item()
            log_prior = self._prior.log_prob(torch.Tensor(parameters)).sum().item()
            return log_likelihood + log_prior

        # The sampler is constructed while the likelihood is in eval mode;
        # training mode is restored right afterwards.
        self._neural_likelihood.eval()
        initial_parameters = utils.tensor2numpy(
            self._prior.sample((1, ))).reshape(-1)
        self.posterior_sampler = SliceSampler(
            initial_parameters,
            lp_f=target_log_prob,
            thin=10,
        )
        self._neural_likelihood.train()

        # Per-round storage for (parameter, observation) pairs.
        self._parameter_bank = []
        self._observation_bank = []

        # Each SNL run gets its own TensorBoard log directory unless the
        # caller supplies a writer.
        if summary_writer is None:
            summary_writer = SummaryWriter(
                os.path.join(utils.get_log_root(), "snl", simulator.name,
                             utils.get_timestamp()))
        self._summary_writer = summary_writer

        # Summary statistics populated over the course of training; every
        # entry starts as its own empty list.
        summary_keys = (
            "mmds",
            "median-observation-distances",
            "negative-log-probs-true-parameters",
            "neural-net-fit-times",
            "mcmc-times",
            "epochs",
            "best-validation-log-probs",
        )
        self._summary = {key: [] for key in summary_keys}
Esempio n. 21
0
# Script section: build the custom CNN, restore its saved weights, render an
# architecture diagram, then continue training on freshly prepared batches.
# NOTE(review): CNN's positional arguments look like (height, width, channels,
# number of classes) -- confirm against the CNN definition.
"""Upload customed cnn model"""
cnn = CNN(256, 256, 3, 101)
cnn.load_weights('weights/custom/cnn_plus.h5')
plot_model(cnn, to_file='./model.png', show_shapes=True, show_layer_names=True)


train_model(2, 'cnn_plus', cnn, srgan)

# Disabled checkpoint callback setup, kept for reference.
# NOTE(review): `callbacks_list` is still passed to cnn.fit() below; unless it
# is defined elsewhere in the file that call raises NameError -- verify.
#filepath="./cnn_weights.h5"
#checkpoint = ModelCheckpoint(filepath, monitor='accuracy', verbose=1, save_best_only=True, mode='max')
#callbacks_list = [checkpoint]

# NOTE(review): the string below says "10 iterations", but the loop runs twice
# with 5 epochs per fit() call.
"""Prepare and train on a batch of data and labels, 10 iterations"""
for i in range(2):
    # Pick a training subset and load it; `tensor2numpy` here is a project
    # helper taking a data directory, the subset and the srgan object.
    train_set = devide(24, 2, 2)
    X = tensor2numpy('./data/', train_set, srgan)
    # Collect the mapping's values in key order and stack them as float64.
    x = [X[i] for i in X.keys()]
    train = np.array(x, dtype = "float64")
    y = create_onehot(X)
    # 20% of the samples passed to fit() are held out for validation.
    history = cnn.fit(train, y, batch_size=32, epochs=5, callbacks=callbacks_list, validation_split=0.2)
    # Plot training & validation loss values
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    # The curve labelled 'Test' is the validation loss from validation_split.
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.show()

# VGG19 convolutional base with ImageNet weights and no classifier head.
"""Upload, use transfer learning"""
VGG=VGG19(input_shape=(224,224,3),include_top=False,weights='imagenet')
Esempio n. 22
0
    def train(self):
        """Run the GAN training loop over ``self.train_reader()``.

        Training runs from epoch ``self.start`` up to a hard-coded limit of
        1000 epochs. When ``self.pretrain`` is set, the weights of all six
        networks are first restored from the ``Parameters/`` checkpoints of
        epoch ``self.start - 1``. Every batch performs one discriminator
        update followed by one generator update and prints the loss terms.

        On odd epochs, whenever ``block_id`` is a multiple of
        ``self.print_freq``, a grid built from up to ten test batches is
        written to ``Images/<epoch>_<block_id>.png``. After every fourth
        epoch the state of all six networks is saved under ``Parameters/``.
        """
        epochs = 1000
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()
        print('training start !')
        start_time = time.time()
        # Load the pretrained model. Only the network parameters are
        # restored; the optimizer states returned by load_dygraph are not
        # applied anywhere in this method.
        if self.pretrain:
            str_genA2B = "Parameters/genA2B%03d.pdparams" % (self.start - 1)
            str_genB2A = "Parameters/genB2A%03d.pdparams" % (self.start - 1)
            str_disGA = "Parameters/disGA%03d.pdparams" % (self.start - 1)
            str_disGB = "Parameters/disGB%03d.pdparams" % (self.start - 1)
            str_disLA = "Parameters/disLA%03d.pdparams" % (self.start - 1)
            str_disLB = "Parameters/disLB%03d.pdparams" % (self.start - 1)
            genA2B_para, gen_A2B_opt = fluid.load_dygraph(str_genA2B)
            genB2A_para, gen_B2A_opt = fluid.load_dygraph(str_genB2A)
            disGA_para, disGA_opt = fluid.load_dygraph(str_disGA)
            disGB_para, disGB_opt = fluid.load_dygraph(str_disGB)
            disLA_para, disLA_opt = fluid.load_dygraph(str_disLA)
            disLB_para, disLB_opt = fluid.load_dygraph(str_disLB)
            self.genA2B.load_dict(genA2B_para)
            self.genB2A.load_dict(genB2A_para)
            self.disGA.load_dict(disGA_para)
            self.disGB.load_dict(disGB_para)
            self.disLA.load_dict(disLA_para)
            self.disLB.load_dict(disLB_para)
        for epoch in range(self.start, epochs):
            for block_id, data in enumerate(self.train_reader()):
                # Each sample of the batch is an (A, B) pair.
                real_A = np.array([x[0] for x in data], np.float32)
                real_B = np.array([x[1] for x in data], np.float32)
                real_A = totensor(real_A, block_id, 'train')
                real_B = totensor(real_B, block_id, 'train')

                # Update D

                fake_A2B, _, _ = self.genA2B(real_A)
                fake_B2A, _, _ = self.genB2A(real_B)

                real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
                real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
                real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
                real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                # Discriminators are pushed towards 1 on real images and
                # towards 0 on generated ones.
                D_ad_loss_GA = mse_loss(1, real_GA_logit) + mse_loss(
                    0, fake_GA_logit)
                D_ad_cam_loss_GA = mse_loss(1, real_GA_cam_logit) + mse_loss(
                    0, fake_GA_cam_logit)

                D_ad_loss_LA = mse_loss(1, real_LA_logit) + mse_loss(
                    0, fake_LA_logit)
                D_ad_cam_loss_LA = mse_loss(1, real_LA_cam_logit) + mse_loss(
                    0, fake_LA_cam_logit)

                D_ad_loss_GB = mse_loss(1, real_GB_logit) + mse_loss(
                    0, fake_GB_logit)
                D_ad_cam_loss_GB = mse_loss(1, real_GB_cam_logit) + mse_loss(
                    0, fake_GB_cam_logit)

                D_ad_loss_LB = mse_loss(1, real_LB_logit) + mse_loss(
                    0, fake_LB_logit)
                D_ad_cam_loss_LB = mse_loss(1, real_LB_cam_logit) + mse_loss(
                    0, fake_LB_cam_logit)

                D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                              D_ad_loss_LA + D_ad_cam_loss_LA)
                D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                              D_ad_loss_LB + D_ad_cam_loss_LB)

                Discriminator_loss = D_loss_A + D_loss_B
                Discriminator_loss.backward()
                self.D_opt.minimize(Discriminator_loss)
                self.disGA.clear_gradients(), self.disGB.clear_gradients(
                ), self.disLA.clear_gradients(), self.disLB.clear_gradients()

                # Update G

                fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
                fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
                print("fake_A2B.shape:", fake_A2B.shape)
                # Cycle: translate the fakes back to their source domain.
                fake_A2B2A, _, _ = self.genB2A(fake_A2B)
                fake_B2A2B, _, _ = self.genA2B(fake_B2A)

                # Identity: feed each generator an image that is already in
                # its target domain.
                fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
                fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

                fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
                fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
                fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
                fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

                G_ad_loss_GA = mse_loss(1, fake_GA_logit)
                G_ad_cam_loss_GA = mse_loss(1, fake_GA_cam_logit)

                G_ad_loss_LA = mse_loss(1, fake_LA_logit)
                G_ad_cam_loss_LA = mse_loss(1, fake_LA_cam_logit)

                G_ad_loss_GB = mse_loss(1, fake_GB_logit)
                G_ad_cam_loss_GB = mse_loss(1, fake_GB_cam_logit)

                G_ad_loss_LB = mse_loss(1, fake_LB_logit)
                G_ad_cam_loss_LB = mse_loss(1, fake_LB_cam_logit)

                G_recon_loss_A = self.L1loss(fake_A2B2A, real_A)
                G_recon_loss_B = self.L1loss(fake_B2A2B, real_B)

                G_identity_loss_A = self.L1loss(fake_A2A, real_A)
                G_identity_loss_B = self.L1loss(fake_B2B, real_B)

                G_cam_loss_A = bce_loss(1, fake_B2A_cam_logit) + bce_loss(
                    0, fake_A2A_cam_logit)
                G_cam_loss_B = bce_loss(1, fake_A2B_cam_logit) + bce_loss(
                    0, fake_B2B_cam_logit)

                G_loss_A = self.adv_weight * (
                    G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                    G_ad_cam_loss_LA
                ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
                G_loss_B = self.adv_weight * (
                    G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                    G_ad_cam_loss_LB
                ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

                Generator_loss = G_loss_A + G_loss_B
                Generator_loss.backward()
                self.G_opt.minimize(Generator_loss)
                self.genA2B.clear_gradients(), self.genB2A.clear_gradients()

                print("[%5d/%5d] time: %4.4f d_loss: %.5f, g_loss: %.5f" %
                      (epoch, block_id, time.time() - start_time,
                       Discriminator_loss.numpy(), Generator_loss.numpy()))
                print("G_loss_A: %.5f G_loss_B: %.5f" %
                      (G_loss_A.numpy(), G_loss_B.numpy()))
                print("G_ad_loss_GA: %.5f   G_ad_loss_GB: %.5f" %
                      (G_ad_loss_GA.numpy(), G_ad_loss_GB.numpy()))
                print("G_ad_loss_LA: %.5f   G_ad_loss_LB: %.5f" %
                      (G_ad_loss_LA.numpy(), G_ad_loss_LB.numpy()))
                print("G_cam_loss_A:%.5f  G_cam_loss_B:%.5f" %
                      (G_cam_loss_A.numpy(), G_cam_loss_B.numpy()))
                print("G_recon_loss_A:%.5f  G_recon_loss_B:%.5f" %
                      (G_recon_loss_A.numpy(), G_recon_loss_B.numpy()))
                # Report the A-side identity loss under its own label (the
                # B-side value used to be printed for both).
                print("G_identity_loss_A:%.5f  G_identity_loss_B:%.5f" %
                      (G_identity_loss_A.numpy(), G_identity_loss_B.numpy()))

                if epoch % 2 == 1 and block_id % self.print_freq == 0:

                    # Sample grid: one column is appended per test batch.
                    A2B = np.zeros((self.img_size * 7, 0, 3))
                    # B2A = np.zeros((self.img_size * 7, 0, 3))
                    for eval_id, eval_data in enumerate(self.test_reader()):
                        # Only the first ten test batches are rendered.
                        if eval_id == 10:
                            break
                        real_A = np.array([x[0] for x in eval_data],
                                          np.float32)
                        real_B = np.array([x[1] for x in eval_data],
                                          np.float32)
                        real_A = totensor(real_A, eval_id, 'eval')
                        real_B = totensor(real_B, eval_id, 'eval')

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                            fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                            fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        # Seven tiles per column: input, then a heatmap and
                        # image for identity, translation and cycle.
                        a = tensor2numpy(denorm(real_A[0]))
                        b = cam(tensor2numpy(fake_A2A_heatmap[0]),
                                self.img_size)
                        c = tensor2numpy(denorm(fake_A2A[0]))
                        d = cam(tensor2numpy(fake_A2B_heatmap[0]),
                                self.img_size)
                        e = tensor2numpy(denorm(fake_A2B[0]))
                        f = cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                self.img_size)
                        g = tensor2numpy(denorm(fake_A2B2A[0]))
                        A2B = np.concatenate((A2B, (np.concatenate(
                            (a, b, c, d, e, f, g)) * 255).astype(np.uint8)),
                                             1).astype(np.uint8)
                    A2B = Image.fromarray(A2B)
                    A2B.save('Images/%d_%d.png' % (epoch, block_id))
                    self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                    ), self.disGB.train(), self.disLA.train(
                    ), self.disLB.train()
            # Checkpoint all six networks after every fourth epoch.
            if epoch % 4 == 0:
                fluid.save_dygraph(self.genA2B.state_dict(),
                                   "Parameters/genA2B%03d" % (epoch))
                fluid.save_dygraph(self.genB2A.state_dict(),
                                   "Parameters/genB2A%03d" % (epoch))
                fluid.save_dygraph(self.disGA.state_dict(),
                                   "Parameters/disGA%03d" % (epoch))
                fluid.save_dygraph(self.disGB.state_dict(),
                                   "Parameters/disGB%03d" % (epoch))
                fluid.save_dygraph(self.disLA.state_dict(),
                                   "Parameters/disLA%03d" % (epoch))
                fluid.save_dygraph(self.disLB.state_dict(),
                                   "Parameters/disLB%03d" % (epoch))
    flow = create_flow(args.flow_type)
    flow.load_state_dict(torch.load(args.input_ckpt))

    flow.eval()

    estimated_cov = np.cov(train_data, rowvar=False)
    _, pca_v = pca(estimated_cov)
    c = eval_data @ pca_v[:, 0]

    fig, ax = plt.subplots(1, 1, figsize=(2, 2))

    ax.set_xlim([-4, 4])
    ax.set_ylim([-4, 4])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('$z_1$')
    ax.set_ylabel('$z_2$')

    with torch.no_grad():
        proj = flow.transform_to_noise(torch.FloatTensor(eval_data))
        proj = tensor2numpy(proj)
    s = ax.scatter(proj[:, 0], proj[:, 1], c=c.flat, alpha=0.3)
    s.set_rasterized(True)

    if args.output_pdf:
        fig.tight_layout()
        fig.savefig(args.output_pdf, bbox_inches='tight', dpi=150)
    else:
        plt.show()
Esempio n. 24
0
    def test(self):
        """Translate every test sample using the most recent checkpoint.

        Looks in ``<result_dir>/<dataset>/model`` for entries named after an
        iteration number, loads the one with the highest number, and then,
        for each sample of ``self.testA_loader()`` and
        ``self.testB_loader()``, writes an image made of seven tiles
        concatenated along axis 0 (input, identity heatmap and output,
        translation heatmap and output, cycle heatmap and output) to
        ``<result_dir>/<dataset>/test``.

        Prints a failure message and returns without translating anything
        when no numerically named checkpoint entry is found.
        """
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')
        # Compare checkpoint names as integers: a plain string sort ranks
        # "9000" above "10000" and would load a stale checkpoint. Entries
        # that are not purely numeric are ignored.
        iterations = [
            int(name) for name in os.listdir(model_dir) if name.isdigit()
        ]
        if not iterations:
            print("[*] Load FAILURE")
            return

        self.load(model_dir, max(iterations))
        print("[*] Load SUCCESS")

        self.genA2B.eval(), self.genB2A.eval()
        for n, (real_A, _) in enumerate(self.testA_loader()):

            # Add a leading batch axis of size one.
            real_A = np.array([real_A.reshape(3, 256, 256)]).astype("float32")

            real_A = to_variable(real_A)

            fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)

            fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)

            fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)

            A2B = np.concatenate(
                (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                 cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                 cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                 cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)

            # Output files are numbered from 1.
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'test',
                             'A2B_%d.png' % (n + 1)), A2B * 255.0)

        for n, (real_B, _) in enumerate(self.testB_loader()):

            real_B = np.array([real_B.reshape(3, 256, 256)]).astype("float32")

            real_B = to_variable(real_B)

            fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

            fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

            fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

            B2A = np.concatenate(
                (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                 cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                 cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                 cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'test',
                             'B2A_%d.png' % (n + 1)), B2A * 255.0)
Esempio n. 25
0
        tbar.set_description(s)

        summaries = {'loss': loss.detach()}
        for summary, value in summaries.items():
            writer.add_scalar(tag=summary,
                              scalar_value=value,
                              global_step=step)

    if (step + 1) % args.visualize_interval == 0:
        flow.eval()
        log_density_np = []
        for batch in grid_loader:
            batch = batch.to(device)
            _, log_density = flow.log_prob(batch)
            log_density_np = np.concatenate(
                (log_density_np, utils.tensor2numpy(log_density)))

        figure, axes = plt.subplots(1,
                                    3,
                                    figsize=(7.5, 2.5),
                                    sharex=True,
                                    sharey=True)

        cmap = cm.magma
        axes[0].hist2d(utils.tensor2numpy(train_dataset.data[:, 0]),
                       utils.tensor2numpy(train_dataset.data[:, 1]),
                       range=bounds,
                       bins=512,
                       cmap=cmap,
                       rasterized=False)
        axes[0].set_xlim(bounds[0])
        for summary, value in summaries.items():
            writer.add_scalar(tag=summary, scalar_value=value, global_step=step)

    if (step + 1) % args.visualize_interval == 0:
        # Plotting
        aem.eval()
        aem.set_n_proposal_samples_per_input_validation(
            args.n_proposal_samples_per_input_validation)
        log_density_np = []
        log_proposal_density_np = []
        for batch in grid_loader:
            batch = batch.to(device)
            log_density, log_proposal_density, unnormalized_log_density, log_normalizer = aem(
                batch)
            log_density_np = np.concatenate((
                log_density_np, utils.tensor2numpy(log_density)
            ))
            log_proposal_density_np = np.concatenate((
                log_proposal_density_np, utils.tensor2numpy(log_proposal_density)
            ))

        fig, axs = plt.subplots(1, 3, figsize=(7.5, 2.5))

        axs[0].hist2d(train_dataset.data[:, 0], train_dataset.data[:, 1],
                      range=bounds, bins=512, cmap=cm.viridis, rasterized=False)
        axs[0].set_xticks([])
        axs[0].set_yticks([])

        axs[1].pcolormesh(grid_dataset.X, grid_dataset.Y,
                          np.exp(log_proposal_density_np).reshape(grid_dataset.X.shape))
        axs[1].set_xlim(bounds[0])
Esempio n. 27
0
def totensor(imgs):
    """Turn a batch of images into a normalised 256x256 dygraph variable.

    The input is converted to a variable, divided by 255, transposed with
    the permutation (0, 3, 1, 2), resized to 256x256 and finally mapped
    through ``(x - 0.5) / 0.5``.
    """
    scaled = fluid.dygraph.to_variable(imgs) / 255.
    # Move the last axis to position 1 before resizing.
    transposed = fluid.layers.transpose(scaled, (0,3,1,2))
    resized = fluid.layers.image_resize(transposed, (256,256))
    return (resized - 0.5) / 0.5


if __name__ == "__main__":
    # Initialise the `gl` module and set its 'rho' value to 0.
    gl._init()
    gl.set_value('rho',0)
    real_paths = os.listdir(real_source)
    with fluid.dygraph.guard():
        # Build the A->B generator and restore its saved parameters.
        genA2B = ResnetGenerator(in_channels=3, out_channels=3, ngf= 64, n_blocks=4)
        genA2B_para, gen_A2B_opt = fluid.load_dygraph("Parameters/genA2B124.pdparams")
        genA2B.load_dict(genA2B_para)
        count = 0
        for file_name in real_paths:
            source_file = os.path.join(real_source, file_name)
            rgb_pixels = np.array(Image.open(source_file).convert("RGB")).astype(np.float32)
            # Add a leading batch axis before converting to a tensor.
            batch = totensor(rgb_pixels[np.newaxis,:,:,:])
            fakeA2B,_,_ = genA2B(batch)

            # Scale the first generated image by 255 and cast it to uint8.
            fake_pixels = (tensor2numpy(denorm(fakeA2B[0]))*255).astype(np.uint8)
            fake_image = Image.fromarray(fake_pixels)
            # Outputs are numbered in the order os.listdir returned the inputs.
            save_path = os.path.join(fake_path, "%04d"%(count)+"_fake.png")
            count += 1
            fake_image.save(save_path)
Esempio n. 28
0
File: sre.py Progetto: yyht/lfi
    def _summarize(self, round_):
        """Update the running summary statistics and log them for one round.

        Appends to ``self._summary`` the MMD against ground-truth posterior
        samples (best effort), the median distance between the latest
        simulated observations and the true observation, and the negative
        KDE log-probability of the ground-truth parameters under the latest
        parameter samples. It then writes a marginal histogram figure and
        the latest scalar summaries to ``self._summary_writer``.

        Args:
            round_: Index of the round being summarized; values are logged
                with ``global_step=round_ + 1``.
        """

        # Update summaries.
        # The MMD is best effort: any ordinary failure while computing it is
        # skipped. Catch Exception rather than using a bare except so that
        # KeyboardInterrupt and SystemExit still propagate.
        try:
            mmd = utils.unbiased_mmd_squared(
                self._parameter_bank[-1],
                self._simulator.get_ground_truth_posterior_samples(
                    num_samples=1000),
            )
            self._summary["mmds"].append(mmd.item())
        except Exception:
            pass

        # Euclidean distance of each latest observation to the true
        # observation, reduced to its median.
        median_observation_distance = torch.median(
            torch.sqrt(
                torch.sum(
                    (self._observation_bank[-1] -
                     self._true_observation.reshape(1, -1))**2,
                    dim=-1,
                )))
        self._summary["median-observation-distances"].append(
            median_observation_distance.item())

        negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
            samples=self._parameter_bank[-1],
            query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
        )
        self._summary["negative-log-probs-true-parameters"].append(
            negative_log_prob_true_parameters.item())

        # Plot most recently sampled parameters in TensorBoard.
        parameters = utils.tensor2numpy(self._parameter_bank[-1])
        figure = utils.plot_hist_marginals(
            data=parameters,
            ground_truth=utils.tensor2numpy(
                self._simulator.get_ground_truth_parameters()).reshape(-1),
            lims=self._simulator.parameter_plotting_limits,
        )
        self._summary_writer.add_figure(tag="posterior-samples",
                                        figure=figure,
                                        global_step=round_ + 1)

        # (writer tag, summary key) pairs, logged in this order.
        scalar_sources = (
            ("epochs-trained", "epochs"),
            ("best-validation-log-prob", "best-validation-log-probs"),
            ("median-observation-distance", "median-observation-distances"),
            ("negative-log-prob-true-parameters",
             "negative-log-probs-true-parameters"),
        )
        for tag, key in scalar_sources:
            self._summary_writer.add_scalar(
                tag=tag,
                scalar_value=self._summary[key][-1],
                global_step=round_ + 1,
            )

        # The MMD list stays empty when it could never be computed.
        if self._summary["mmds"]:
            self._summary_writer.add_scalar(
                tag="mmd",
                scalar_value=self._summary["mmds"][-1],
                global_step=round_ + 1,
            )

        self._summary_writer.flush()
Esempio n. 29
0
                              scalar_value=value,
                              global_step=step)

# Load the checkpoint saved as '<dataset_name>-best-val-<timestamp>.t' and
# switch the flow to evaluation mode.
path = os.path.join(cutils.get_checkpoint_root(),
                    '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
flow.load_state_dict(torch.load(path))
flow.eval()

# Calculate the log-likelihood on the test set, without tracking gradients.
with torch.no_grad():
    # NOTE(review): this accumulator is created on the default device while
    # batches are moved to `device`; torch.cat needs both on the same device
    # -- confirm the default tensor type matches `device`.
    log_likelihood = torch.Tensor([])
    for batch in tqdm(test_loader):
        # flow.log_prob returns a pair here; only its second element is kept.
        _, log_density = flow.log_prob(batch.to(device))
        log_likelihood = torch.cat([log_likelihood, log_density])
# Save the collected log-likelihood values as a .npy file in log_dir.
path = os.path.join(
    log_dir, '{}-{}-log-likelihood.npy'.format(args.dataset_name,
                                               args.base_transform_type))
np.save(path, utils.tensor2numpy(log_likelihood))
mean_log_likelihood = log_likelihood.mean()
std_log_likelihood = log_likelihood.std()

# Format the final score as mean +- 2 * std / sqrt(len(test_dataset)), print
# it, and write it to test-results.txt in log_dir.
s = 'Final score for {}: {:.2f} +- {:.2f}'.format(
    args.dataset_name.capitalize(), mean_log_likelihood.item(),
    2 * std_log_likelihood.item() / np.sqrt(len(test_dataset)))
print(s)
filename = os.path.join(log_dir, 'test-results.txt')
with open(filename, 'w') as file:
    file.write(s)
Esempio n. 30
0
    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = os.listdir(
                os.path.join(self.result_dir, self.dataset, 'model'))
            if not len(model_list) == 0:
                model_list.sort()
                iter = int(model_list[-1])
                print("[*]load %d" % (iter))
                self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                          iter)
                print("[*] Load SUCCESS")

        # training loop
        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            real_A = next(self.trainA_loader)
            real_B = next(self.trainB_loader)
            real_A = np.array([real_A[0].reshape(3, 256,
                                                 256)]).astype("float32")
            real_B = np.array([real_B[0].reshape(3, 256,
                                                 256)]).astype("float32")
            real_A = to_variable(real_A)
            real_B = to_variable(real_B)
            # Update D

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit, ones_like(real_GA_logit)) + self.MSE_loss(
                    fake_GA_logit, zeros_like(fake_GA_logit))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                ones_like(real_GA_cam_logit)) + self.MSE_loss(
                    fake_GA_cam_logit, zeros_like(fake_GA_cam_logit))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit, ones_like(real_LA_logit)) + self.MSE_loss(
                    fake_LA_logit, zeros_like(fake_LA_logit))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                ones_like(real_LA_cam_logit)) + self.MSE_loss(
                    fake_LA_cam_logit, zeros_like(fake_LA_cam_logit))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit, ones_like(real_GB_logit)) + self.MSE_loss(
                    fake_GB_logit, zeros_like(fake_GB_logit))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                ones_like(real_GB_cam_logit)) + self.MSE_loss(
                    fake_GB_cam_logit, zeros_like(fake_GB_cam_logit))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit, ones_like(real_LB_logit)) + self.MSE_loss(
                    fake_LB_logit, zeros_like(fake_LB_logit))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                ones_like(real_LB_cam_logit)) + self.MSE_loss(
                    fake_LB_cam_logit, zeros_like(fake_LB_cam_logit))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)
            self.genB2A.clear_gradients()
            self.genA2B.clear_gradients()
            self.disGA.clear_gradients()
            self.disLA.clear_gradients()
            self.disGB.clear_gradients()
            self.disLB.clear_gradients()
            self.D_optim.clear_gradients()

            # Update G

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit,
                                         ones_like(fake_GA_logit))
            G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit,
                                             ones_like(fake_GA_cam_logit))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit,
                                         ones_like(fake_LA_logit))
            G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit,
                                             ones_like(fake_LA_cam_logit))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit,
                                         ones_like(fake_GB_logit))
            G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit,
                                             ones_like(fake_GB_cam_logit))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit,
                                         ones_like(fake_LB_logit))
            G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit,
                                             ones_like(fake_LB_cam_logit))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
                    fake_A2A_cam_logit, zeros_like(fake_A2A_cam_logit))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
                    fake_B2B_cam_logit, zeros_like(fake_B2B_cam_logit))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)
            self.genB2A.clear_gradients()
            self.genA2B.clear_gradients()
            self.disGA.clear_gradients()
            self.disLA.clear_gradients()
            self.disGB.clear_gradients()
            self.disLB.clear_gradients()
            self.G_optim.clear_gradients()

            self.Rho_clipper(self.genA2B)
            self.Rho_clipper(self.genB2A)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))

            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    real_A = next(self.trainA_loader)
                    real_B = next(self.trainB_loader)
                    real_A = np.array([real_A[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_B = np.array([real_B[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_A = to_variable(real_A)
                    real_B = to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                for _ in range(test_sample_num):
                    real_A = next(self.testA_loader())
                    real_B = next(self.testB_loader())
                    real_A = np.array([real_A[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_B = np.array([real_B[0].reshape(3, 256, 256)
                                       ]).astype("float32")
                    real_A = to_variable(real_A)
                    real_B = to_variable(real_B)

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))),
                             0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))),
                             0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()
            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)

            if step % 1000 == 0:
                fluid.save_dygraph(
                    self.genA2B.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/genA2B"))
                fluid.save_dygraph(
                    self.genB2A.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/genB2A"))
                fluid.save_dygraph(
                    self.disGA.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disGA"))
                fluid.save_dygraph(
                    self.disGB.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disGB"))
                fluid.save_dygraph(
                    self.disLA.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disLA"))
                fluid.save_dygraph(
                    self.disLB.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/disLB"))
                fluid.save_dygraph(
                    self.D_optim.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/D_optim"))
                fluid.save_dygraph(
                    self.G_optim.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/G_optim"))
                fluid.save_dygraph(
                    self.genA2B.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/D_optim"))
                fluid.save_dygraph(
                    self.genB2A.state_dict(),
                    os.path.join(self.result_dir,
                                 self.dataset + "/latest/new/G_optim"))