def main():
    """Run SNL on the M/G/1 task and save a corner plot of the posterior.

    Trains a MAF neural likelihood for 10 rounds of 1000 simulations, draws
    1000 posterior samples with slice sampling and writes the marginals,
    overlaid with the true parameters, to ./corner-posterior-snl.pdf.
    """
    simulator, prior = simulators.get_simulator_and_prior("mg1")
    dims = (simulator.parameter_dim, simulator.observation_dim)
    true_observation = simulator.get_ground_truth_observation()

    snl = SNL(
        simulator=simulator,
        true_observation=true_observation,
        prior=prior,
        neural_likelihood=utils.get_neural_likelihood("maf", *dims),
        mcmc_method="slice-np",
    )
    snl.run_inference(num_rounds=10, num_simulations_per_round=1000)

    posterior_samples = utils.tensor2numpy(snl.sample_posterior(1000))
    true_parameters = utils.tensor2numpy(
        simulator.get_ground_truth_parameters()
    ).reshape(-1)
    figure = utils.plot_hist_marginals(
        data=posterior_samples,
        ground_truth=true_parameters,
        lims=simulator.parameter_plotting_limits,
    )
    figure.savefig("./corner-posterior-snl.pdf")
def test_():
    """Run APT on the nonlinear-Gaussian task and save two corner plots.

    One plot uses i.i.d. posterior samples (2500), the other MCMC samples
    (1000); both are written under utils.get_output_root().
    """
    simulator, prior = simulators.get_simulator_and_prior("nonlinear-gaussian")
    dims = (simulator.parameter_dim, simulator.observation_dim)
    true_observation = simulator.get_ground_truth_observation()
    neural_posterior = utils.get_neural_posterior("maf", *dims, simulator)

    apt = APT(
        simulator=simulator,
        true_observation=true_observation,
        prior=prior,
        neural_posterior=neural_posterior,
        num_atoms=-1,
        use_combined_loss=False,
        train_with_mcmc=False,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        discard_prior_samples=False,
    )
    apt.run_inference(num_rounds=20, num_simulations_per_round=1000)

    def save_corner_plot(samples, filename):
        # Marginal histograms of `samples`, overlaid with the true parameters.
        figure = utils.plot_hist_marginals(
            data=utils.tensor2numpy(samples),
            ground_truth=utils.tensor2numpy(
                simulator.get_ground_truth_parameters()
            ).reshape(-1),
            lims=simulator.parameter_plotting_limits,
        )
        figure.savefig(os.path.join(utils.get_output_root(), filename))

    save_corner_plot(apt.sample_posterior(2500), "corner-posterior-apt.pdf")
    save_corner_plot(
        apt.sample_posterior_mcmc(num_samples=1000),
        "corner-posterior-apt-mcmc.pdf",
    )
def test_():
    """Smoke-test Pyro NUTS on the oscillating Lotka-Volterra target.

    Builds a potential from ``LotkaVolterraOscillating`` and a uniform
    prior on [-5, 2]^4, runs 3 NUTS chains (200 warm-up steps, 10000 draws
    in total) and saves a corner plot of the draws to
    ``<output root>/mcmc.pdf``.
    """
    import os

    from nsf import distributions as distributions_
    from pyro.infer.mcmc import NUTS

    parameter_dim = 4
    prior = distributions.Uniform(
        low=-5 * torch.ones(parameter_dim), high=2 * torch.ones(parameter_dim)
    )
    likelihood = distributions_.LotkaVolterraOscillating()
    potential_function = PotentialFunction(likelihood, prior)
    kernel = NUTS(potential_fn=potential_function)

    num_chains = 3
    sampler = MCMC(
        kernel=kernel,
        num_samples=10000 // num_chains,
        warmup_steps=200,
        initial_params={"": torch.zeros(num_chains, parameter_dim)},
        num_chains=num_chains,
    )
    sampler.run()
    samples = next(iter(sampler.get_samples().values()))

    # No ground-truth overlay: the previous one was a leftover 2-D vector
    # from a Gaussian experiment and did not match these 4-D samples.
    figure = utils.plot_hist_marginals(utils.tensor2numpy(samples), lims=[-6, 3])
    # Write to the project output root rather than a user-specific path.
    figure.savefig(os.path.join(utils.get_output_root(), "mcmc.pdf"))
    plt.close(figure)
def gridimshow(image, ax):
    """Draw a channels-first image tensor on `ax` with no frame or tick marks.

    A single-channel image is shown inverted (``1 - image``) with the
    'Greys' colormap. Otherwise the tensor is permuted to channels-last and
    shown as-is.
    """
    if image.shape[0] != 1:
        ax.imshow(utils.tensor2numpy(image.permute(1, 2, 0)))
    else:
        grey = utils.tensor2numpy(image[0, ...])
        ax.imshow(1 - grey, cmap='Greys')

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')
def simulate(self, parameters):
    """Simulate a batch of single-server queues; return inter-departure times.

    Each parameter row is (p1, p2, p3). Service times are drawn uniformly
    from [p1, p2), and inter-arrival times are exponential with rate p3.
    The result is a torch.Tensor of inter-departure times, or the
    summarizer's output when ``self._summarize_observations`` is set.
    """
    theta = utils.tensor2numpy(parameters)
    assert theta.shape[1] == 3, "parameter must be 3-dimensional"
    low, high, rate = theta[:, 0:1], theta[:, 1:2], theta[:, 2:3]
    batch_size = theta.shape[0]
    # NOTE(review): the attribute name looks like a search/replace artifact
    # of `n_sim_steps`; confirm the class really defines this name.
    num_steps = self.n_sim_steparameters

    service = (high - low) * np.random.rand(batch_size, num_steps) + low
    inter_arrival = -np.log(1.0 - np.random.rand(batch_size, num_steps)) / rate
    arrival = np.cumsum(inter_arrival, axis=1)

    inter_departure = np.empty([batch_size, num_steps], dtype=float)
    inter_departure[:, 0] = service[:, 0] + arrival[:, 0]
    departure = np.empty([batch_size, num_steps], dtype=float)
    departure[:, 0] = inter_departure[:, 0]
    for step in range(1, num_steps):
        # The server idles until this customer arrives, if it is not already waiting.
        idle = np.maximum(0.0, arrival[:, step] - departure[:, step - 1])
        inter_departure[:, step] = service[:, step] + idle
        departure[:, step] = departure[:, step - 1] + inter_departure[:, step]

    self.num_total_simulations += batch_size
    if self._summarize_observations:
        inter_departure = self._summarizer(inter_departure)
    return torch.Tensor(inter_departure)
def simulate(self, parameters):
    """
    Generates observations for the given batch of parameters.

    :param parameters: torch.Tensor or np.ndarray
        Batch of parameters, or a single 1-D parameter vector (treated as a
        batch of one; the batch dimension is kept in the result).
    :return: torch.Tensor
        Batch of observations.
    """
    # The simulator itself works on NumPy arrays.
    if isinstance(parameters, torch.Tensor):
        parameters = utils.tensor2numpy(parameters)

    if parameters.ndim == 1:
        return self.simulate(parameters[None, ...])

    batch_size = parameters.shape[0]
    self.num_total_simulations += batch_size

    # Noise: a point on the right half of a circle of radius ~N(0.1, 0.01^2)
    # centred at (0.25, 0).
    angle = np.pi * np.random.rand(batch_size) - np.pi / 2
    radius = 0.01 * np.random.randn(batch_size) + 0.1
    noise = np.column_stack([radius * np.cos(angle) + 0.25, radius * np.sin(angle)])

    theta_0, theta_1 = parameters[:, 0], parameters[:, 1]
    shift = (1 / np.sqrt(2)) * np.column_stack([
        -np.abs(theta_0 + theta_1),
        (-theta_0 + theta_1),
    ])
    return torch.Tensor(noise + shift)
def test_():
    """Run SRE on Lotka-Volterra; save a corner plot and, if recorded, the MMD curve."""
    simulator, prior = simulators.get_simulator_and_prior("lotka-volterra")
    parameter_dim = simulator.parameter_dim
    observation_dim = simulator.observation_dim
    true_observation = simulator.get_ground_truth_observation()

    ratio_estimator = SRE(
        simulator=simulator,
        true_observation=true_observation,
        classifier=utils.get_classifier("mlp", parameter_dim, observation_dim),
        prior=prior,
        num_atoms=-1,
        mcmc_method="slice-np",
        retrain_from_scratch_each_round=False,
    )
    num_rounds = 10
    num_simulations_per_round = 1000
    ratio_estimator.run_inference(
        num_rounds=num_rounds, num_simulations_per_round=num_simulations_per_round
    )

    posterior_samples = utils.tensor2numpy(
        ratio_estimator.sample_posterior(num_samples=2500)
    )
    true_parameters = utils.tensor2numpy(
        simulator.get_ground_truth_parameters()
    ).reshape(-1)
    corner_figure = utils.plot_hist_marginals(
        data=posterior_samples, ground_truth=true_parameters, lims=[-4, 4]
    )
    corner_figure.savefig(
        os.path.join(utils.get_output_root(), "corner-posterior-ratio.pdf")
    )

    mmds = ratio_estimator.summary["mmds"]
    if not mmds:
        return
    mmd_figure, mmd_axes = plt.subplots(1, 1)
    simulation_counts = np.arange(
        0, num_rounds * num_simulations_per_round, num_simulations_per_round
    )
    mmd_axes.plot(simulation_counts, np.array(mmds), "-o", linewidth=2)
    mmd_figure.savefig(os.path.join(utils.get_output_root(), "mmd-ratio.pdf"))
def log_prob(self, observations, parameters):
    """
    Evaluate the analytic log likelihood log p(x | theta).

    Each observation holds self._num_observations_per_parameter draws from
    a 2D Gaussian. The means, scales and correlation come from
    self._unpack_params. The draws are treated as independent, so the log
    likelihood is the sum of the per-draw Gaussian log densities.

    :param observations: torch.Tensor or np.ndarray [batch_size, observation_dim]
        Batch of observations. A single 1-D observation is accepted when
        `parameters` is also 1-D.
    :param parameters: torch.Tensor or np.ndarray [batch_size, parameter_dim]
        Batch of parameters.
    :return: np.ndarray [batch_size]
        Log likelihood for each item in the batch (NumPy, not torch).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = utils.tensor2numpy(parameters)
    if isinstance(observations, torch.Tensor):
        observations = utils.tensor2numpy(observations)

    if observations.ndim == 1 and parameters.ndim == 1:
        observations = observations.reshape(1, -1)
        parameters = parameters.reshape(1, -1)

    draws_per_parameter = self._num_observations_per_parameter
    m0, m1, s0, s1, r = self._unpack_params(parameters)
    # log sqrt(det Sigma) of the 2x2 covariance.
    half_log_det = np.log(s0) + np.log(s1) + 0.5 * np.log(1.0 - r**2)

    draws = observations.reshape([observations.shape[0], draws_per_parameter, 2])
    # Whiten each draw: standardize x0, then remove its correlation from x1.
    whitened = np.empty_like(draws)
    whitened[:, :, 0] = (draws[:, :, 0] - m0) / s0
    whitened[:, :, 1] = (
        draws[:, :, 1] - m1 - s1 * r * whitened[:, :, 0]
    ) / (s1 * np.sqrt(1.0 - r**2))
    whitened = whitened.reshape([whitened.shape[0], 2 * draws_per_parameter])

    per_item = np.sum(scipy.stats.norm.logpdf(whitened), axis=1)
    return per_item - draws_per_parameter * half_log_det[:, 0]
def test_change(self): model_list = os.listdir( os.path.join(self.result_dir, self.dataset, 'model')) if not len(model_list) == 0: model_list.sort() iter = int(model_list[-1].split('/')[-1]) self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter) print("[*] Load SUCCESS") else: print("[*] Load FAILURE") return self.genA2B.eval(), self.genB2A.eval() for n, (real_A, fname) in enumerate(self.testA_loader()): real_A = np.array([real_A[0].reshape(3, 256, 256)]).astype("float32") real_A = to_variable(real_A) fake_A2B, _, _ = self.genA2B(real_A) A2B = RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))) cv2.imwrite( os.path.join( self.result_dir, self.dataset, 'test', 'testA2B', '%s_fake.%s' % (fname.split('.')[0], fname.split('.')[-1])), A2B * 255.0) for n, (real_B, fname) in enumerate(self.testB_loader()): real_B = np.array([real_B[0].reshape(3, 256, 256)]).astype("float32") real_B = to_variable(real_B) fake_B2A, _, _ = self.genB2A(real_B) B2A = RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))) cv2.imwrite( os.path.join( self.result_dir, self.dataset, 'test', 'testB2A', '%s_fake.%s' % (fname.split('.')[0], fname.split('.')[-1])), B2A * 255.0)
def _test():
    """Draw 1000 samples from LotkaVolterraOscillating and show their marginals."""
    target = LotkaVolterraOscillating()
    draws = utils.tensor2numpy(target.sample((1000,)))
    utils.plot_hist_marginals(draws, lims=[-6, 3])
    plt.show()
def simulate(self, parameters):
    """Run the jump-process simulator once per parameter row.

    Parameters are exponentiated before use (presumably they are given as
    log-rates — confirm with the caller). A run that raises
    SimTooLongException yields ``None`` in place of a tensor. Every
    attempted run, successful or not, increments
    ``self.num_total_simulations``.

    :return: list of torch.Tensor / None, or the summarizer's output when
        ``self._summarize_observations`` is set.
    """
    rates = np.exp(utils.tensor2numpy(parameters))
    observations = []
    for rate in rates:
        try:
            self._jump_process.reset(self._initial_populations, rate)
            states = self._jump_process.simulate_for_time(
                self._dt, self._duration, max_n_steps=self._max_num_steps)
            observation = torch.Tensor(states.flatten())
        except SimTooLongException:
            observation = None
        observations.append(observation)
        self.num_total_simulations += 1

    if not self._summarize_observations:
        return observations
    return self._summarizer(observations)
def simulate(self, parameters):
    """
    Generates normalized observations for the given batch of parameters.

    Each parameter vector is unpacked (via self._unpack_params) into two
    means, two scales and a correlation of a 2D Gaussian. From it,
    self._num_observations_per_parameter points are drawn, flattened and
    standardized with the stored observation mean/std.

    :param parameters: torch.Tensor or np.ndarray
        Batch of parameters, or a single 1-D parameter vector.
    :return: torch.Tensor
        Batch of observations; for a single 1-D parameter vector the batch
        dimension is dropped.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = utils.tensor2numpy(parameters)

    if parameters.ndim == 1:
        return self.simulate(parameters[np.newaxis, :])[0]

    batch_size = parameters.shape[0]
    self.num_total_simulations += batch_size

    draws_per_parameter = self._num_observations_per_parameter
    m0, m1, s0, s1, r = self._unpack_params(parameters)
    noise = np.random.randn(batch_size, draws_per_parameter, 2)
    # Map standard-normal noise to the correlated Gaussian via its 2x2
    # lower-triangular factor.
    samples = np.empty_like(noise)
    samples[:, :, 0] = s0 * noise[:, :, 0] + m0
    samples[:, :, 1] = (
        s1 * (r * noise[:, :, 0] + np.sqrt(1.0 - r**2) * noise[:, :, 1]) + m1)

    mean, std = self._get_observation_normalization_parameters()
    flat = torch.Tensor(samples.reshape([batch_size, 2 * draws_per_parameter]))
    return (flat - mean.reshape(1, -1)) / std.reshape(1, -1)
def viz_to_tb(dataloader, writer, num_classes, display_num=4):
    """Log a grid of up to `display_num` example images per label to TensorBoard.

    Batches of ``(inputs, labels)`` are read from `dataloader` until every
    label in ``range(num_classes)`` has `display_num` examples, or the
    loader is exhausted. Whatever was collected is then written: one
    ``writer.add_image`` call per label.

    :param dataloader: iterable of (inputs, labels) batches.
    :param writer: object with an ``add_image(tag, img, dataformats=...)`` method.
    :param num_classes: number of labels; every label must be in ``range(num_classes)``.
    :param display_num: maximum number of examples per label.
    """
    from collections import defaultdict
    from torchvision.utils import make_grid
    from utils import tensor2numpy

    labels_count = {k: 0 for k in range(num_classes)}
    imgs_dict = defaultdict(list)
    dl_iter = iter(dataloader)
    # Keep reading while ANY label still needs examples. The old `all(...)`
    # condition stopped as soon as a single label was full.
    while any(count < display_num for count in labels_count.values()):
        try:
            inputs, labels = next(dl_iter)
        except StopIteration:
            break  # Some labels are rare: log what was collected.
        for input_, label in zip(inputs, labels):
            label = int(label)
            if labels_count[label] < display_num:
                imgs_dict[label].append(input_)
                labels_count[label] += 1

    for label, imgs in imgs_dict.items():
        # NOTE(review): dataformats='HWC' assumes tensor2numpy returns a
        # channels-last array — confirm against utils.tensor2numpy.
        img_grid = tensor2numpy(make_grid(imgs))
        writer.add_image(f'example image for label {label}', img_grid,
                         dataformats='HWC')
def simulate(self, parameters):
    """Run the wrapped simulator and return its summary statistics.

    The parameters are converted to NumPy and passed to
    ``self._simulator.sim``. The result is summarized by
    ``self._summarizer.calc`` and returned as a torch.Tensor.
    """
    raw_output = self._simulator.sim(utils.tensor2numpy(parameters))
    return torch.Tensor(self._summarizer.calc(raw_output))
def __init__(
    self,
    simulator,
    prior,
    true_observation,
    classifier,
    num_atoms=-1,
    mcmc_method="slice-np",
    summary_net=None,
    retrain_from_scratch_each_round=False,
    summary_writer=None,
):
    """
    Set up sequential ratio estimation (SRE) of the posterior p(theta | x0).

    :param simulator: Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter as a
        torch.Tensor. Its 'name' attribute is used in the log directory.
    :param prior: Distribution object with 'log_prob' and 'sample' methods.
    :param true_observation: torch.Tensor containing the observation x0 for which to
        perform inference on the posterior p(theta | x0).
    :param classifier: Binary classifier in the form of an nets.Module.
        Outputs pre-sigmoid activations for (parameter, observation) pairs.
        NOTE(review): the slice-sampling target below feeds it
        torch.cat((theta, x0)), i.e. parameters first — confirm this matches
        the order used in training.
    :param num_atoms: int
        Number of atoms to use for classification. If -1, use all other parameters
        in minibatch. Must be an int (asserted).
    :param mcmc_method: str
        Name of the MCMC method; only stored here as self._mcmc_method.
    :param summary_net: Optional network which may be used to produce feature vectors
        f(x) for high-dimensional observations. If None, nn.Identity() is used.
    :param retrain_from_scratch_each_round: Whether to retrain the classifier from
        scratch each round. If True, a deep copy of the untrained classifier is kept;
        otherwise self._untrained_classifier is None.
    :param summary_writer: Optional SummaryWriter. If None, one is created with log
        directory <log root>/sre/<simulator.name>/<timestamp>.
    """
    self._simulator = simulator
    self._true_observation = true_observation
    self._classifier = classifier
    self._prior = prior
    assert isinstance(num_atoms, int), "Number of atoms must be an integer."
    self._num_atoms = num_atoms
    self._mcmc_method = mcmc_method

    # We may want to summarize high-dimensional observations.
    # This may be either a fixed or learned transformation.
    if summary_net is None:
        self._summary_net = nn.Identity()
    else:
        self._summary_net = summary_net

    # Defining the potential function as an object means Pyro's MCMC scheme
    # can pickle it to be used across multiple chains in parallel, even if
    # the potential function requires evaluating a neural likelihood as is the
    # case here.
    self._potential_function = NeuralPotentialFunction(
        classifier, prior, true_observation)

    # TODO: decide on Slice Sampling implementation
    # Unnormalized log target for the NumPy slice sampler: the classifier's
    # output on concatenated (parameters, x0) plus the summed prior log density.
    target_log_prob = (lambda parameters: self._classifier(
        torch.cat(
            (torch.Tensor(parameters), self._true_observation)).reshape(
                1, -1)).item() + self._prior.log_prob(
                    torch.Tensor(parameters)).sum().item())
    # The classifier is in eval mode only while the sampler is built
    # (presumably SliceSampler evaluates lp_f at its initial point — confirm),
    # then put back in train mode.
    self._classifier.eval()
    self.posterior_sampler = SliceSampler(
        utils.tensor2numpy(self._prior.sample((1, ))).reshape(-1),
        lp_f=target_log_prob,
        thin=10,
    )
    self._classifier.train()

    self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
    # If we're retraining from scratch each round,
    # keep a copy of the original untrained model for reinitialization.
    if retrain_from_scratch_each_round:
        self._untrained_classifier = deepcopy(classifier)
    else:
        self._untrained_classifier = None

    # Need somewhere to store (parameter, observation) pairs from each round.
    self._parameter_bank, self._observation_bank = [], []

    # Each SRE run has an associated log directory for TensorBoard output.
    if summary_writer is None:
        log_dir = os.path.join(utils.get_log_root(), "sre", simulator.name,
                               utils.get_timestamp())
        self._summary_writer = SummaryWriter(log_dir)
    else:
        self._summary_writer = summary_writer

    # Each run also has a dictionary of summary statistics which are populated
    # over the course of training.
    self._summary = {
        "mmds": [],
        "median-observation-distances": [],
        "negative-log-probs-true-parameters": [],
        "neural-net-fit-times": [],
        "mcmc-times": [],
        "epochs": [],
        "best-validation-log-probs": [],
    }
def train(self):
    """Run the U-GAT-IT adversarial training loop.

    If ``self.resume`` is set, the checkpoint with the largest step in
    ``<result_dir>/<dataset>/model`` is loaded. When linear LR decay is
    enabled and already under way, both optimizers' learning rates are
    reduced by the decay accumulated up to that step. Each step then:

    - performs one discriminator update and one generator update on one
      batch from each domain;
    - clips both generators with ``self.Rho_clipper``;
    - every ``print_freq`` steps, writes sample strips to
      ``<result_dir>/<dataset>/img``;
    - every ``save_freq`` steps, saves a checkpoint;
    - every 1000 steps, writes all state dicts to
      ``<result_dir>/<dataset>_params_latest.pt``.
    """
    self.genA2B.train(), self.genB2A.train()
    self.disGA.train(), self.disGB.train()
    self.disLA.train(), self.disLB.train()

    def next_batch(iterator, loader):
        """Return ``(batch, iterator)``, (re)starting `loader` if `iterator` is None or exhausted."""
        if iterator is not None:
            try:
                return next(iterator), iterator
            except StopIteration:
                pass
        iterator = iter(loader)
        return next(iterator), iterator

    def mse(logit, is_real):
        """Least-squares loss pushing `logit` towards 1 (real) or 0 (fake)."""
        target = torch.ones_like(logit) if is_real else torch.zeros_like(logit)
        return self.MSE_loss(logit, target)

    def bce(logit, is_real):
        """BCE loss pushing a CAM `logit` towards 1 or 0."""
        target = torch.ones_like(logit) if is_real else torch.zeros_like(logit)
        return self.BCE_loss(logit, target)

    def disc_loss(real_logit, fake_logit):
        """Discriminator term: real logits towards 1, fake logits towards 0."""
        return mse(real_logit, True) + mse(fake_logit, False)

    def image_column(real, id_heatmap, identity, tr_heatmap, translated,
                     cyc_heatmap, cycled):
        """Stack input, CAM heatmaps and outputs of the first batch item along axis 0."""
        return np.concatenate((
            RGB2BGR(tensor2numpy(denorm(real[0]))),
            cam(tensor2numpy(id_heatmap[0]), self.img_size),
            RGB2BGR(tensor2numpy(denorm(identity[0]))),
            cam(tensor2numpy(tr_heatmap[0]), self.img_size),
            RGB2BGR(tensor2numpy(denorm(translated[0]))),
            cam(tensor2numpy(cyc_heatmap[0]), self.img_size),
            RGB2BGR(tensor2numpy(denorm(cycled[0]))),
        ), 0)

    def sample_columns(real_A, real_B):
        """Run both generators on one batch per domain; return the (A2B, B2A) columns."""
        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
        column_A2B = image_column(real_A, fake_A2A_heatmap, fake_A2A,
                                  fake_A2B_heatmap, fake_A2B,
                                  fake_A2B2A_heatmap, fake_A2B2A)
        column_B2A = image_column(real_B, fake_B2B_heatmap, fake_B2B,
                                  fake_B2A_heatmap, fake_B2A,
                                  fake_B2A2B_heatmap, fake_B2A2B)
        return column_A2B, column_B2A

    half_iteration = self.iteration // 2
    start_iter = 1
    if self.resume:
        model_dir = os.path.join(self.result_dir, self.dataset, 'model')
        model_list = glob(os.path.join(model_dir, '*.pt'))
        if model_list:
            # Checkpoint names end in '_<step>.pt'; resume from the
            # numerically largest step, not the last one in string order.
            start_iter = max(
                int(os.path.basename(path).split('_')[-1].split('.')[0])
                for path in model_list)
            self.load(model_dir, start_iter)
            print(" [*] Load SUCCESS")
            if self.decay_flag and start_iter > half_iteration:
                decay = (self.lr / half_iteration) * (start_iter - half_iteration)
                self.G_optim.param_groups[0]['lr'] -= decay
                self.D_optim.param_groups[0]['lr'] -= decay

    # training loop
    print('training start !')
    start_time = time.time()
    trainA_iter = trainB_iter = testA_iter = testB_iter = None
    for step in range(start_iter, self.iteration + 1):
        if self.decay_flag and step > half_iteration:
            self.G_optim.param_groups[0]['lr'] -= (self.lr / half_iteration)
            self.D_optim.param_groups[0]['lr'] -= (self.lr / half_iteration)

        # Use the builtin next(): DataLoader iterators in current PyTorch have
        # no `.next()` method. Restart the loader only on StopIteration.
        (real_A, _), trainA_iter = next_batch(trainA_iter, self.trainA_loader)
        (real_B, _), trainB_iter = next_batch(trainB_iter, self.trainB_loader)
        real_A, real_B = real_A.to(self.device), real_B.to(self.device)

        # Update D
        self.D_optim.zero_grad()
        # The generators are not updated in this phase, so their forward
        # passes need no autograd graph.
        with torch.no_grad():
            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        D_loss_A = self.adv_weight * (
            disc_loss(real_GA_logit, fake_GA_logit)
            + disc_loss(real_GA_cam_logit, fake_GA_cam_logit)
            + disc_loss(real_LA_logit, fake_LA_logit)
            + disc_loss(real_LA_cam_logit, fake_LA_cam_logit))
        D_loss_B = self.adv_weight * (
            disc_loss(real_GB_logit, fake_GB_logit)
            + disc_loss(real_GB_cam_logit, fake_GB_cam_logit)
            + disc_loss(real_LB_logit, fake_LB_logit)
            + disc_loss(real_LB_cam_logit, fake_LB_cam_logit))

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.step()

        # Update G
        self.G_optim.zero_grad()
        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)
        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        # Generators try to make the discriminators output 1 on fakes.
        G_ad_loss_A = (mse(fake_GA_logit, True) + mse(fake_GA_cam_logit, True)
                       + mse(fake_LA_logit, True) + mse(fake_LA_cam_logit, True))
        G_ad_loss_B = (mse(fake_GB_logit, True) + mse(fake_GB_cam_logit, True)
                       + mse(fake_LB_logit, True) + mse(fake_LB_cam_logit, True))

        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)
        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        # Generator CAM logits: 1 for cross-domain inputs, 0 for same-domain inputs.
        G_cam_loss_A = bce(fake_B2A_cam_logit, True) + bce(fake_A2A_cam_logit, False)
        G_cam_loss_B = bce(fake_A2B_cam_logit, True) + bce(fake_B2B_cam_logit, False)

        G_loss_A = (self.adv_weight * G_ad_loss_A
                    + self.cycle_weight * G_recon_loss_A
                    + self.identity_weight * G_identity_loss_A
                    + self.cam_weight * G_cam_loss_A)
        G_loss_B = (self.adv_weight * G_ad_loss_B
                    + self.cycle_weight * G_recon_loss_B
                    + self.identity_weight * G_identity_loss_B
                    + self.cam_weight * G_cam_loss_B)

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.step()

        # clip parameter of AdaILN and ILN, applied after optimizer step
        self.genA2B.apply(self.Rho_clipper)
        self.genB2A.apply(self.Rho_clipper)

        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (
            step, self.iteration, time.time() - start_time,
            Discriminator_loss, Generator_loss))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval(), self.genB2A.eval()
            self.disGA.eval(), self.disGB.eval()
            self.disLA.eval(), self.disLB.eval()
            # Visualization only: no autograd graphs are needed.
            with torch.no_grad():
                for _ in range(train_sample_num):
                    (real_A, _), trainA_iter = next_batch(trainA_iter, self.trainA_loader)
                    (real_B, _), trainB_iter = next_batch(trainB_iter, self.trainB_loader)
                    column_A2B, column_B2A = sample_columns(
                        real_A.to(self.device), real_B.to(self.device))
                    A2B = np.concatenate((A2B, column_A2B), 1)
                    B2A = np.concatenate((B2A, column_B2A), 1)

                for _ in range(test_sample_num):
                    (real_A, _), testA_iter = next_batch(testA_iter, self.testA_loader)
                    (real_B, _), testB_iter = next_batch(testB_iter, self.testB_loader)
                    column_A2B, column_B2A = sample_columns(
                        real_A.to(self.device), real_B.to(self.device))
                    A2B = np.concatenate((A2B, column_A2B), 1)
                    B2A = np.concatenate((B2A, column_B2A), 1)

            img_dir = os.path.join(self.result_dir, self.dataset, 'img')
            cv2.imwrite(os.path.join(img_dir, 'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(os.path.join(img_dir, 'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train(), self.genB2A.train()
            self.disGA.train(), self.disGB.train()
            self.disLA.train(), self.disLB.train()

        if step % self.save_freq == 0:
            self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)

        if step % 1000 == 0:
            params = {
                'genA2B': self.genA2B.state_dict(),
                'genB2A': self.genB2A.state_dict(),
                'disGA': self.disGA.state_dict(),
                'disGB': self.disGB.state_dict(),
                'disLA': self.disLA.state_dict(),
                'disLB': self.disLB.state_dict(),
            }
            torch.save(params, os.path.join(self.result_dir,
                                            self.dataset + '_params_latest.pt'))
def test(self): model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt')) if not len(model_list) == 0: model_list.sort() iter = int(model_list[-1].split('_')[-1].split('.')[0]) self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter) print(" [*] Load SUCCESS") else: print(" [*] Load FAILURE") return self.genA2B.eval(), self.genB2A.eval() for n, (real_A, _) in enumerate(self.testA_loader): real_A = real_A.to(self.device) fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0) cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0) for n, (real_B, _) in enumerate(self.testB_loader): real_B = real_B.to(self.device) fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0) cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0)
def __init__(
    self,
    simulator,
    prior,
    true_observation,
    neural_posterior,
    num_atoms=-1,
    use_combined_loss=False,
    train_with_mcmc=False,
    mcmc_method="slice-np",
    summary_net=None,
    retrain_from_scratch_each_round=False,
    discard_prior_samples=False,
    summary_writer=None,
):
    """
    Set up APT inference of the posterior p(theta | x0).

    :param simulator:
        Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter
        as a torch.Tensor. Its 'name' attribute is used in the log directory.
    :param prior: Distribution
        Distribution object with 'log_prob' and 'sample' methods.
    :param true_observation: torch.Tensor [observation_dim] or [1, observation_dim]
        True observation x0 for which to perform inference on the posterior p(theta | x0).
    :param neural_posterior: nets.Module
        Conditional density estimator q(theta | x) with 'log_prob' and 'sample' methods.
    :param num_atoms: int
        Number of atoms to use for classification.
        If -1, use all other parameters in minibatch. Must be an int (asserted).
    :param use_combined_loss: bool
        Whether to jointly train prior samples using maximum likelihood.
        Useful to prevent density leaking when using box uniform priors.
    :param train_with_mcmc: bool
        Whether to sample using MCMC instead of i.i.d. sampling at the end of each round.
    :param mcmc_method: str
        MCMC method to use if 'train_with_mcmc' is True, e.g. 'slice-np' (the default),
        'hmc' or 'nuts'. NOTE(review): confirm the exact accepted names where the value
        is consumed; this constructor only stores it.
    :param summary_net: nets.Module
        Optional network which may be used to produce feature vectors
        f(x) for high-dimensional observations. If None, nn.Identity() is used.
    :param retrain_from_scratch_each_round: bool
        Whether to retrain the conditional density estimator for the posterior
        from scratch each round.
    :param discard_prior_samples: bool
        Whether to discard prior samples from round two onwards.
    :param summary_writer: SummaryWriter
        Optionally pass summary writer. If None, one is created with log directory
        <log root>/apt/<simulator.name>/<timestamp>.
    """
    self._simulator = simulator
    self._prior = prior
    self._true_observation = true_observation
    self._neural_posterior = neural_posterior
    assert isinstance(num_atoms, int), "Number of atoms must be an integer."
    self._num_atoms = num_atoms
    self._use_combined_loss = use_combined_loss

    # We may want to summarize high-dimensional observations.
    # This may be either a fixed or learned transformation.
    if summary_net is None:
        self._summary_net = nn.Identity()
    else:
        self._summary_net = summary_net

    self._mcmc_method = mcmc_method
    self._train_with_mcmc = train_with_mcmc

    # HMC and NUTS from Pyro.
    # Defining the potential function as an object means Pyro's MCMC scheme
    # can pickle it to be used across multiple chains in parallel, even if
    # the potential function requires evaluating a neural likelihood as is the
    # case here.
    self._potential_function = NeuralPotentialFunction(
        neural_posterior, prior, self._true_observation
    )

    # Axis-aligned slice sampling implementation in NumPy.
    # Log target: log q(theta | x0) from the neural posterior, or -inf wherever
    # the summed prior log-probability is infinite (e.g. outside a bounded
    # prior's support).
    target_log_prob = (
        lambda parameters: self._neural_posterior.log_prob(
            inputs=torch.Tensor(parameters).reshape(1, -1),
            context=self._true_observation.reshape(1, -1),
        ).item()
        if not np.isinf(self._prior.log_prob(torch.Tensor(parameters)).sum().item())
        else -np.inf
    )
    # The estimator is in eval mode only while the sampler is built
    # (presumably SliceSampler evaluates lp_f at its initial point — confirm),
    # then put back in train mode.
    self._neural_posterior.eval()
    self.posterior_sampler = SliceSampler(
        utils.tensor2numpy(self._prior.sample((1,))).reshape(-1),
        lp_f=target_log_prob,
        thin=10,
    )
    self._neural_posterior.train()

    self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
    # A copy of the original untrained model is kept for reinitialization.
    # NOTE(review): unlike SRE, the copy is made unconditionally, even when
    # retrain_from_scratch_each_round is False.
    self._untrained_neural_posterior = deepcopy(neural_posterior)

    self._discard_prior_samples = discard_prior_samples

    # Need somewhere to store (parameter, observation) pairs from each round.
    self._parameter_bank, self._observation_bank, self._prior_masks = [], [], []

    self._model_bank = []

    self._total_num_generated_examples = 0

    # Each APT run has an associated log directory for TensorBoard output.
    if summary_writer is None:
        log_dir = os.path.join(
            utils.get_log_root(), "apt", simulator.name, utils.get_timestamp()
        )
        self._summary_writer = SummaryWriter(log_dir)
    else:
        self._summary_writer = summary_writer

    # Each run also has a dictionary of summary statistics which are populated
    # over the course of training.
    self._summary = {
        "mmds": [],
        "median-observation-distances": [],
        "negative-log-probs-true-parameters": [],
        "neural-net-fit-times": [],
        "epochs": [],
        "best-validation-log-probs": [],
        "rejection-sampling-acceptance-rates": [],
    }
def run(seed):
    """Train and evaluate a VAE with optional flow prior / flow posterior.

    All configuration is read from the module-level ``args`` namespace. The
    model is trained for ``args.num_training_steps`` steps. Every
    ``args.monitor_interval`` steps the validation ELBO is computed on one fixed
    validation batch, the best model so far is checkpointed, and samples are
    logged to TensorBoard. Afterwards the best checkpoint is reloaded and the
    ELBO and a log-probability lower bound are computed on the test set. They
    are saved as ``.npy`` files plus a ``test-results.txt`` summary in the
    run's log directory.

    :param seed: Seed for the NumPy and PyTorch RNGs used during training.
    :raises RuntimeError: If CUDA is not available.
    :raises ValueError: If ``args`` names an unknown dataset, linear
        transform type, base transform type or KL-multiplier schedule.
    """
    # Raise explicitly rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError('This experiment requires a CUDA device.')
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    np.random.seed(seed)
    torch.manual_seed(seed)

    # ---------------
    # data
    # ---------------
    # Pixel intensities (scaled to [0, 1] by ToTensor) are used as Bernoulli
    # probabilities, i.e. images are binarized each time the transform runs.
    data_transform = tvtransforms.Compose(
        [tvtransforms.ToTensor(),
         tvtransforms.Lambda(torch.bernoulli)])

    if args.dataset_name == 'mnist':
        dataset = datasets.MNIST(
            root=os.path.join(utils.get_data_root(), 'mnist'),
            train=True,
            download=True,
            transform=data_transform)
        test_dataset = datasets.MNIST(
            root=os.path.join(utils.get_data_root(), 'mnist'),
            train=False,
            download=True,
            transform=data_transform)
    elif args.dataset_name == 'fashion-mnist':
        dataset = datasets.FashionMNIST(
            root=os.path.join(utils.get_data_root(), 'fashion-mnist'),
            train=True,
            download=True,
            transform=data_transform)
        test_dataset = datasets.FashionMNIST(
            root=os.path.join(utils.get_data_root(), 'fashion-mnist'),
            train=False,
            download=True,
            transform=data_transform)
    elif args.dataset_name == 'omniglot':
        dataset = data_.OmniglotDataset(split='train',
                                        transform=data_transform)
        test_dataset = data_.OmniglotDataset(split='test',
                                             transform=data_transform)
    elif args.dataset_name == 'emnist':
        # Rotate by -90 degrees and flip horizontally before binarizing
        # (presumably to undo EMNIST's transposed image storage -- confirm).
        rotate = partial(tvF.rotate, angle=-90)
        hflip = tvF.hflip
        data_transform = tvtransforms.Compose([
            tvtransforms.Lambda(rotate),
            tvtransforms.Lambda(hflip),
            tvtransforms.ToTensor(),
            tvtransforms.Lambda(torch.bernoulli)
        ])
        dataset = datasets.EMNIST(
            root=os.path.join(utils.get_data_root(), 'emnist'),
            split='letters',
            train=True,
            transform=data_transform,
            download=True)
        test_dataset = datasets.EMNIST(
            root=os.path.join(utils.get_data_root(), 'emnist'),
            split='letters',
            train=False,
            transform=data_transform,
            download=True)
    else:
        raise ValueError

    # Hold out the last |split| examples of a random permutation of the
    # training set for validation.
    if args.dataset_name == 'omniglot':
        split = -1345
    elif args.dataset_name == 'emnist':
        split = -20000
    else:
        split = -10000
    indices = np.arange(len(dataset))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[:split], indices[split:]
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    train_loader = data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4 if args.dataset_name == 'emnist' else 0)
    train_generator = data_.batch_generator(train_loader)
    val_loader = data.DataLoader(dataset=dataset,
                                 batch_size=1024,
                                 sampler=val_sampler,
                                 shuffle=False,
                                 drop_last=False)
    # Validation during training uses this single batch of (up to) 1024
    # randomly sampled held-out images.
    val_batch = next(iter(val_loader))[0]
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
    )

    # ---------------
    # flow building blocks
    # ---------------
    def create_linear_transform():
        # Invertible linear layer inserted between flow steps.
        if args.linear_type == 'lu':
            return transforms.CompositeTransform([
                transforms.RandomPermutation(args.latent_features),
                transforms.LULinear(args.latent_features, identity_init=True)
            ])
        elif args.linear_type == 'svd':
            return transforms.SVDLinear(args.latent_features,
                                        num_householder=4,
                                        identity_init=True)
        elif args.linear_type == 'perm':
            return transforms.RandomPermutation(args.latent_features)
        else:
            raise ValueError

    def create_base_transform(i, context_features=None):
        # Non-linear flow step i. The coupling mask alternates with the parity
        # of i. Note that the step type is chosen by args.prior_type for both
        # the prior and the approximate posterior.
        def create_resnet(in_features, out_features):
            # Conditioner network used by the coupling transforms.
            return nn_.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)

        if args.prior_type == 'affine-coupling':
            return transforms.AffineCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=create_resnet)
        elif args.prior_type == 'rq-coupling':
            return transforms.PiecewiseRationalQuadraticCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=create_resnet,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                apply_unconditional_transform=args.
                apply_unconditional_transform,
            )
        elif args.prior_type == 'affine-autoregressive':
            return transforms.MaskedAffineAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        elif args.prior_type == 'rq-autoregressive':
            return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        else:
            raise ValueError

    # ---------------
    # prior
    # ---------------
    def create_prior():
        if args.prior_type == 'standard-normal':
            prior = distributions_.StandardNormal((args.latent_features, ))
        else:
            distribution = distributions_.StandardNormal(
                (args.latent_features, ))
            transform = transforms.CompositeTransform([
                transforms.CompositeTransform(
                    [create_linear_transform(),
                     create_base_transform(i)])
                for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            prior = flows.Flow(transform, distribution)
        return prior

    # ---------------
    # inputs encoder
    # ---------------
    def create_inputs_encoder():
        # The diagonal-normal posterior has its own conv encoder, so no
        # separate inputs encoder is needed in that case.
        if args.approximate_posterior_type == 'diagonal-normal':
            inputs_encoder = None
        else:
            inputs_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
        return inputs_encoder

    # ---------------
    # approximate posterior
    # ---------------
    def create_approximate_posterior():
        if args.approximate_posterior_type == 'diagonal-normal':
            context_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            approximate_posterior = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)
        else:
            context_encoder = nn.Linear(args.context_features,
                                        2 * args.latent_features)
            distribution = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)
            transform = transforms.CompositeTransform([
                transforms.CompositeTransform([
                    create_linear_transform(),
                    create_base_transform(
                        i, context_features=args.context_features)
                ]) for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            approximate_posterior = flows.Flow(
                transforms.InverseTransform(transform), distribution)
        return approximate_posterior

    # ---------------
    # likelihood
    # ---------------
    def create_likelihood():
        latent_decoder = nn_.ConvDecoder(
            latent_features=args.latent_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)
        likelihood = distributions_.ConditionalIndependentBernoulli(
            shape=[1, 28, 28], context_encoder=latent_decoder)
        return likelihood

    # Creation order is kept fixed: random components (e.g. permutations)
    # consume RNG state in this order.
    prior = create_prior()
    approximate_posterior = create_approximate_posterior()
    likelihood = create_likelihood()
    inputs_encoder = create_inputs_encoder()

    model = vae.VariationalAutoencoder(
        prior=prior,
        approximate_posterior=approximate_posterior,
        likelihood=likelihood,
        inputs_encoder=inputs_encoder)

    n_params = utils.get_num_parameters(model)
    print('There are {} trainable parameters in this model.'.format(n_params))

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=args.num_training_steps, eta_min=0)

    def get_kl_multiplier(step):
        # 'constant': always kl_multiplier_initial.
        # 'linear': ramps from kl_multiplier_initial to twice that value over
        # the first kl_warmup_fraction of training, then stays there.
        if args.kl_multiplier_schedule == 'constant':
            return args.kl_multiplier_initial
        elif args.kl_multiplier_schedule == 'linear':
            multiplier = min(
                step / (args.num_training_steps * args.kl_warmup_fraction),
                1.)
            return args.kl_multiplier_initial * (1. + multiplier)
        else:
            # Previously fell through and returned None, which was then
            # passed on as the KL multiplier.
            raise ValueError

    # create summary writer and write to log directory
    timestamp = cutils.get_timestamp()
    if cutils.on_cluster():
        timestamp += '||{}'.format(os.environ['SLURM_JOB_ID'])
    log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
    # Retry writer creation on FileExistsError (presumably a race with
    # concurrent jobs creating the same directories -- confirm).
    while True:
        try:
            writer = SummaryWriter(log_dir=log_dir, max_queue=20)
            break
        except FileExistsError:
            sleep(5)
    filename = os.path.join(log_dir, 'config.json')
    with open(filename, 'w') as file:
        json.dump(vars(args), file)

    best_val_elbo = -np.inf
    tbar = tqdm(range(args.num_training_steps))
    for step in tbar:
        model.train()
        optimizer.zero_grad()
        # NOTE(review): passing the step to scheduler.step() is deprecated in
        # recent PyTorch; kept to preserve the closed-form LR schedule.
        scheduler.step(step)

        batch = next(train_generator)[0].to(device)
        elbo = model.stochastic_elbo(batch,
                                     kl_multiplier=get_kl_multiplier(step))
        loss = -torch.mean(elbo)
        loss.backward()
        optimizer.step()

        if (step + 1) % args.monitor_interval == 0:
            model.eval()
            with torch.no_grad():
                elbo = model.stochastic_elbo(val_batch.to(device))
                mean_val_elbo = elbo.mean()

            # Keep a checkpoint of the model with the best validation ELBO.
            if mean_val_elbo > best_val_elbo:
                best_val_elbo = mean_val_elbo
                path = os.path.join(
                    cutils.get_checkpoint_root(),
                    '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
                torch.save(model.state_dict(), path)

            writer.add_scalar(tag='val-elbo',
                              scalar_value=mean_val_elbo,
                              global_step=step)

            writer.add_scalar(tag='best-val-elbo',
                              scalar_value=best_val_elbo,
                              global_step=step)

            with torch.no_grad():
                samples = model.sample(64)
            fig, ax = plt.subplots(figsize=(10, 10))
            cutils.gridimshow(make_grid(samples.view(64, 1, 28, 28), nrow=8),
                              ax)
            writer.add_figure(tag='vae-samples', figure=fig, global_step=step)
            plt.close()

    # Flush and release the TensorBoard writer once training is done.
    writer.close()

    # load best val model
    path = os.path.join(
        cutils.get_checkpoint_root(),
        '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
    model.load_state_dict(torch.load(path))
    model.eval()

    # Fixed seed so the stochastic test-set estimates are reproducible.
    np.random.seed(5)
    torch.manual_seed(5)

    # compute elbo and log-prob lower bound on test set
    with torch.no_grad():
        elbo = torch.Tensor([])
        log_prob_lower_bound = torch.Tensor([])
        for batch in tqdm(test_loader):
            elbo_ = model.stochastic_elbo(batch[0].to(device))
            elbo = torch.cat([elbo, elbo_])
            log_prob_lower_bound_ = model.log_prob_lower_bound(
                batch[0].to(device), num_samples=1000)
            log_prob_lower_bound = torch.cat(
                [log_prob_lower_bound, log_prob_lower_bound_])

    # save per-example elbo and log prob lower bound
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-elbo.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(elbo))
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-log-prob-lower-bound.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(log_prob_lower_bound))

    # Summarize as mean +- two standard errors over the test set.
    mean_elbo = elbo.mean()
    std_elbo = elbo.std()
    mean_log_prob_lower_bound = log_prob_lower_bound.mean()
    std_log_prob_lower_bound = log_prob_lower_bound.std()
    s = 'ELBO: {:.2f} +- {:.2f}, LOG PROB LOWER BOUND: {:.2f} +- {:.2f}'.format(
        mean_elbo.item(), 2 * std_elbo.item() / np.sqrt(len(test_dataset)),
        mean_log_prob_lower_bound.item(),
        2 * std_log_prob_lower_bound.item() / np.sqrt(len(test_dataset)))
    filename = os.path.join(log_dir, 'test-results.txt')
    with open(filename, 'w') as file:
        file.write(s)
def __init__(
    self,
    simulator,
    prior,
    true_observation,
    neural_likelihood,
    mcmc_method="slice-np",
    summary_writer=None,
):
    """
    :param simulator:
        Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter
        as a torch.Tensor.
    :param prior:
        Distribution object with 'log_prob' and 'sample' methods.
    :param true_observation:
        torch.Tensor containing the observation x0 for which to
        perform inference on the posterior p(theta | x0).
    :param neural_likelihood:
        Conditional density estimator q(x | theta) in the form of an
        nn.Module. Must have 'log_prob' and 'sample' methods.
    :param mcmc_method:
        MCMC method to use for posterior sampling. The default
        'slice-np' selects the NumPy slice sampler built here
        (``self.posterior_sampler``); the Pyro-based methods ('slice',
        'hmc', 'nuts') use ``self._potential_function``.
    :param summary_writer:
        Optional TensorBoard SummaryWriter. If None, a new writer is
        created under ``<log root>/snl/<simulator.name>/<timestamp>``.
    """

    self._simulator = simulator
    self._prior = prior
    self._true_observation = true_observation
    self._neural_likelihood = neural_likelihood
    self._mcmc_method = mcmc_method

    # Defining the potential function as an object means Pyro's MCMC scheme
    # can pickle it to be used across multiple chains in parallel, even if
    # the potential function requires evaluating a neural likelihood as is the
    # case here.
    self._potential_function = NeuralPotentialFunction(
        neural_likelihood=self._neural_likelihood,
        prior=self._prior,
        true_observation=self._true_observation,
    )

    # TODO: decide on Slice Sampling implementation
    def target_log_prob(parameters):
        # Unnormalized log posterior log q(x0 | theta) + log p(theta) for the
        # NumPy slice sampler. Points outside the prior's support are rejected
        # before the neural likelihood is evaluated, as in APT's slice-sampler
        # target. This skips a wasted network call and avoids NaN when the
        # likelihood is non-finite there.
        parameters = torch.Tensor(parameters)
        prior_log_prob = self._prior.log_prob(parameters).sum()
        if torch.isinf(prior_log_prob):
            return -float("inf")
        likelihood_log_prob = self._neural_likelihood.log_prob(
            inputs=self._true_observation.reshape(1, -1),
            context=parameters.reshape(1, -1),
        ).item()
        return likelihood_log_prob + prior_log_prob.item()

    # The sampler is constructed with the likelihood in eval mode, then the
    # network is switched back to training mode.
    self._neural_likelihood.eval()
    self.posterior_sampler = SliceSampler(
        utils.tensor2numpy(self._prior.sample((1, ))).reshape(-1),
        lp_f=target_log_prob,
        thin=10,
    )
    self._neural_likelihood.train()

    # Need somewhere to store (parameter, observation) pairs from each round.
    self._parameter_bank, self._observation_bank = [], []

    # Each SNL run has an associated log directory for TensorBoard output.
    if summary_writer is None:
        log_dir = os.path.join(utils.get_log_root(), "snl", simulator.name,
                               utils.get_timestamp())
        self._summary_writer = SummaryWriter(log_dir)
    else:
        self._summary_writer = summary_writer

    # Each run also has a dictionary of summary statistics which are populated
    # over the course of training.
    self._summary = {
        "mmds": [],
        "median-observation-distances": [],
        "negative-log-probs-true-parameters": [],
        "neural-net-fit-times": [],
        "mcmc-times": [],
        "epochs": [],
        "best-validation-log-probs": [],
    }
# Load the custom CNN, restore its weights and export an architecture diagram.
cnn = CNN(256, 256, 3, 101)
cnn.load_weights('weights/custom/cnn_plus.h5')
plot_model(cnn, to_file='./model.png', show_shapes=True, show_layer_names=True)
train_model(2, 'cnn_plus', cnn, srgan)

# `callbacks_list` used to be defined only in the commented-out checkpoint
# setup, so `cnn.fit(..., callbacks=callbacks_list)` below raised NameError.
# No callbacks are active by default. To re-enable checkpointing, use:
#   filepath = "./cnn_weights.h5"
#   checkpoint = ModelCheckpoint(filepath, monitor='accuracy', verbose=1,
#                                save_best_only=True, mode='max')
#   callbacks_list = [checkpoint]
callbacks_list = []

# Prepare and train on freshly drawn batches of data and labels (2 rounds).
for _ in range(2):
    train_set = devide(24, 2, 2)
    X = tensor2numpy('./data/', train_set, srgan)
    train = np.array([X[key] for key in X.keys()], dtype="float64")
    y = create_onehot(X)
    history = cnn.fit(train, y, batch_size=32, epochs=5,
                      callbacks=callbacks_list, validation_split=0.2)

# Plot training & validation loss of the last training round.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Load ImageNet-pretrained VGG19 without its classifier head for transfer learning.
VGG = VGG19(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
def train(self):
    """Run the U-GAT-IT training loop for epochs ``self.start`` .. 999.

    For each batch from ``self.train_reader()``, one discriminator update
    (``self.D_opt``) is followed by one generator update (``self.G_opt``),
    and the losses are printed.

    On odd epochs, every ``self.print_freq`` batches, up to 10 test batches
    are translated A->B and saved as an image strip to
    ``Images/<epoch>_<batch>.png``. Every 4th epoch all six networks are saved
    under ``Parameters/``.

    If ``self.pretrain`` is set, the network parameters saved for epoch
    ``self.start - 1`` are loaded first. Optimizer states are not restored.
    """
    epochs = 1000
    net_names = ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB')
    networks = (self.genA2B, self.genB2A, self.disGA, self.disGB,
                self.disLA, self.disLB)

    def set_train_mode():
        # Put all six networks into training mode.
        for net in networks:
            net.train()

    def clear_all_gradients():
        # Backpropagating either loss reaches the parameters of both the
        # generators and the discriminators. If only the just-updated networks
        # were cleared, the leftover gradients would leak into the other
        # optimizer's next update, so everything is cleared after each step.
        for net in networks:
            net.clear_gradients()

    set_train_mode()
    print('training start !')
    start_time = time.time()

    # Load the pretrained model (parameters only).
    if self.pretrain:
        for name, net in zip(net_names, networks):
            params, _ = fluid.load_dygraph(
                "Parameters/%s%03d.pdparams" % (name, self.start - 1))
            net.load_dict(params)

    for epoch in range(self.start, epochs):
        for block_id, data in enumerate(self.train_reader()):
            real_A = np.array([x[0] for x in data], np.float32)
            real_B = np.array([x[1] for x in data], np.float32)
            real_A = totensor(real_A, block_id, 'train')
            real_B = totensor(real_B, block_id, 'train')

            # Update D
            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            # Discriminators push real logits toward 1 and fake logits
            # toward 0 (assuming mse_loss(target, logit) compares against a
            # constant target -- confirm).
            D_ad_loss_GA = mse_loss(1, real_GA_logit) + mse_loss(
                0, fake_GA_logit)
            D_ad_cam_loss_GA = mse_loss(1, real_GA_cam_logit) + mse_loss(
                0, fake_GA_cam_logit)
            D_ad_loss_LA = mse_loss(1, real_LA_logit) + mse_loss(
                0, fake_LA_logit)
            D_ad_cam_loss_LA = mse_loss(1, real_LA_cam_logit) + mse_loss(
                0, fake_LA_cam_logit)
            D_ad_loss_GB = mse_loss(1, real_GB_logit) + mse_loss(
                0, fake_GB_logit)
            D_ad_cam_loss_GB = mse_loss(1, real_GB_cam_logit) + mse_loss(
                0, fake_GB_cam_logit)
            D_ad_loss_LB = mse_loss(1, real_LB_logit) + mse_loss(
                0, fake_LB_logit)
            D_ad_cam_loss_LB = mse_loss(1, real_LB_cam_logit) + mse_loss(
                0, fake_LB_cam_logit)

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_opt.minimize(Discriminator_loss)
            clear_all_gradients()

            # Update G
            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            # Adversarial terms: generators try to make fakes score as real.
            G_ad_loss_GA = mse_loss(1, fake_GA_logit)
            G_ad_cam_loss_GA = mse_loss(1, fake_GA_cam_logit)
            G_ad_loss_LA = mse_loss(1, fake_LA_logit)
            G_ad_cam_loss_LA = mse_loss(1, fake_LA_cam_logit)
            G_ad_loss_GB = mse_loss(1, fake_GB_logit)
            G_ad_cam_loss_GB = mse_loss(1, fake_GB_cam_logit)
            G_ad_loss_LB = mse_loss(1, fake_LB_logit)
            G_ad_cam_loss_LB = mse_loss(1, fake_LB_cam_logit)

            # Cycle-consistency (A->B->A, B->A->B) and identity (A->A, B->B).
            G_recon_loss_A = self.L1loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1loss(fake_B2B, real_B)

            G_cam_loss_A = bce_loss(1, fake_B2A_cam_logit) + bce_loss(
                0, fake_A2A_cam_logit)
            G_cam_loss_B = bce_loss(1, fake_A2B_cam_logit) + bce_loss(
                0, fake_B2B_cam_logit)

            G_loss_A = (self.adv_weight *
                        (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                         G_ad_cam_loss_LA) +
                        self.cycle_weight * G_recon_loss_A +
                        self.identity_weight * G_identity_loss_A +
                        self.cam_weight * G_cam_loss_A)
            G_loss_B = (self.adv_weight *
                        (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                         G_ad_cam_loss_LB) +
                        self.cycle_weight * G_recon_loss_B +
                        self.identity_weight * G_identity_loss_B +
                        self.cam_weight * G_cam_loss_B)

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_opt.minimize(Generator_loss)
            clear_all_gradients()

            print("[%5d/%5d] time: %4.4f d_loss: %.5f, g_loss: %.5f" %
                  (epoch, block_id, time.time() - start_time,
                   Discriminator_loss.numpy(), Generator_loss.numpy()))
            print("G_loss_A: %.5f G_loss_B: %.5f" %
                  (G_loss_A.numpy(), G_loss_B.numpy()))
            print("G_ad_loss_GA: %.5f G_ad_loss_GB: %.5f" %
                  (G_ad_loss_GA.numpy(), G_ad_loss_GB.numpy()))
            print("G_ad_loss_LA: %.5f G_ad_loss_LB: %.5f" %
                  (G_ad_loss_LA.numpy(), G_ad_loss_LB.numpy()))
            print("G_cam_loss_A:%.5f G_cam_loss_B:%.5f" %
                  (G_cam_loss_A.numpy(), G_cam_loss_B.numpy()))
            print("G_recon_loss_A:%.5f G_recon_loss_B:%.5f" %
                  (G_recon_loss_A.numpy(), G_recon_loss_B.numpy()))
            print("G_identity_loss_A:%.5f G_identity_loss_B:%.5f" %
                  (G_identity_loss_A.numpy(), G_identity_loss_B.numpy()))

            if epoch % 2 == 1 and block_id % self.print_freq == 0:
                # Visualize with the generators in evaluation mode; training
                # mode is restored below.
                self.genA2B.eval()
                self.genB2A.eval()
                # Each test sample appends one column of 7 vertically stacked
                # tiles: real, identity CAM, identity, A2B CAM, A2B,
                # cycle CAM, cycle.
                A2B = np.zeros((self.img_size * 7, 0, 3))
                # B2A = np.zeros((self.img_size * 7, 0, 3))
                for eval_id, eval_data in enumerate(self.test_reader()):
                    if eval_id == 10:
                        break
                    real_A = np.array([x[0] for x in eval_data], np.float32)
                    real_B = np.array([x[1] for x in eval_data], np.float32)
                    real_A = totensor(real_A, eval_id, 'eval')
                    real_B = totensor(real_B, eval_id, 'eval')

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    a = tensor2numpy(denorm(real_A[0]))
                    b = cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size)
                    c = tensor2numpy(denorm(fake_A2A[0]))
                    d = cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size)
                    e = tensor2numpy(denorm(fake_A2B[0]))
                    f = cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size)
                    g = tensor2numpy(denorm(fake_A2B2A[0]))
                    A2B = np.concatenate(
                        (A2B, (np.concatenate(
                            (a, b, c, d, e, f, g)) * 255).astype(np.uint8)),
                        1).astype(np.uint8)
                A2B = Image.fromarray(A2B)
                A2B.save('Images/%d_%d.png' % (epoch, block_id))
                set_train_mode()

        if epoch % 4 == 0:
            for name, net in zip(net_names, networks):
                fluid.save_dygraph(net.state_dict(),
                                   "Parameters/%s%03d" % (name, epoch))
# Visualize where the trained flow maps the evaluation data in its base
# (noise) space. Each point is colored by its projection onto the leading
# principal axis of the training data.
flow = create_flow(args.flow_type)
# The flow is evaluated on CPU tensors below. map_location='cpu' lets a
# checkpoint saved from a GPU run load on a CPU-only machine;
# load_state_dict copies the values onto the parameters' own device.
# NOTE(review): torch.load unpickles its input -- only load trusted checkpoints.
flow.load_state_dict(torch.load(args.input_ckpt, map_location='cpu'))
flow.eval()

estimated_cov = np.cov(train_data, rowvar=False)
# Presumably pca() returns (eigenvalues, eigenvectors-as-columns) sorted by
# decreasing variance -- confirm against its definition.
_, pca_v = pca(estimated_cov)
c = eval_data @ pca_v[:, 0]

fig, ax = plt.subplots(1, 1, figsize=(2, 2))
ax.set_xlim([-4, 4])
ax.set_ylim([-4, 4])
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')

with torch.no_grad():
    proj = flow.transform_to_noise(torch.FloatTensor(eval_data))
proj = tensor2numpy(proj)

s = ax.scatter(proj[:, 0], proj[:, 1], c=c.flat, alpha=0.3)
# Rasterize the many scatter points to keep the PDF small.
s.set_rasterized(True)

if args.output_pdf:
    fig.tight_layout()
    fig.savefig(args.output_pdf, bbox_inches='tight', dpi=150)
else:
    plt.show()
def test(self):
    """Translate every test image in both directions and save visualizations.

    Loads the checkpoint with the highest iteration number found in
    ``<result_dir>/<dataset>/model``. For each image from
    ``self.testA_loader()`` it writes
    ``<result_dir>/<dataset>/test/A2B_<n>.png``; for each image from
    ``self.testB_loader()`` it writes ``B2A_<n>.png``. Each file is a vertical
    strip of: input, identity CAM, identity output, translation CAM,
    translation, cycle CAM, cycle reconstruction.

    If no checkpoint is found, prints "[*] Load FAILURE" and returns.
    """
    model_dir = os.path.join(self.result_dir, self.dataset, 'model')
    # Checkpoint entries are named by iteration number. Compare them as
    # integers: a lexicographic sort would rank '9000' above '10000'.
    iterations = [int(name) for name in os.listdir(model_dir)
                  if name.isdigit()]
    if not iterations:
        print("[*] Load FAILURE")
        return
    self.load(model_dir, max(iterations))
    print("[*] Load SUCCESS")

    self.genA2B.eval()
    self.genB2A.eval()

    def translate_all(loader, forward, backward, prefix):
        # Run forward (translation), backward(forward(x)) (cycle) and
        # backward(x) (identity) on each image and save the image strip.
        for n, (real, _) in enumerate(loader()):
            real = np.array([real.reshape(3, 256, 256)]).astype("float32")
            real = to_variable(real)

            fake, _, fake_heatmap = forward(real)
            fake_cycle, _, fake_cycle_heatmap = backward(fake)
            fake_identity, _, fake_identity_heatmap = backward(real)

            strip = np.concatenate(
                (RGB2BGR(tensor2numpy(denorm(real[0]))),
                 cam(tensor2numpy(fake_identity_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_identity[0]))),
                 cam(tensor2numpy(fake_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake[0]))),
                 cam(tensor2numpy(fake_cycle_heatmap[0]), self.img_size),
                 RGB2BGR(tensor2numpy(denorm(fake_cycle[0])))), 0)
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'test',
                             '%s_%d.png' % (prefix, n + 1)), strip * 255.0)

    translate_all(self.testA_loader, self.genA2B, self.genB2A, 'A2B')
    translate_all(self.testB_loader, self.genB2A, self.genA2B, 'B2A')
tbar.set_description(s) summaries = {'loss': loss.detach()} for summary, value in summaries.items(): writer.add_scalar(tag=summary, scalar_value=value, global_step=step) if (step + 1) % args.visualize_interval == 0: flow.eval() log_density_np = [] for batch in grid_loader: batch = batch.to(device) _, log_density = flow.log_prob(batch) log_density_np = np.concatenate( (log_density_np, utils.tensor2numpy(log_density))) figure, axes = plt.subplots(1, 3, figsize=(7.5, 2.5), sharex=True, sharey=True) cmap = cm.magma axes[0].hist2d(utils.tensor2numpy(train_dataset.data[:, 0]), utils.tensor2numpy(train_dataset.data[:, 1]), range=bounds, bins=512, cmap=cmap, rasterized=False) axes[0].set_xlim(bounds[0])
for summary, value in summaries.items(): writer.add_scalar(tag=summary, scalar_value=value, global_step=step) if (step + 1) % args.visualize_interval == 0: # Plotting aem.eval() aem.set_n_proposal_samples_per_input_validation( args.n_proposal_samples_per_input_validation) log_density_np = [] log_proposal_density_np = [] for batch in grid_loader: batch = batch.to(device) log_density, log_proposal_density, unnormalized_log_density, log_normalizer = aem( batch) log_density_np = np.concatenate(( log_density_np, utils.tensor2numpy(log_density) )) log_proposal_density_np = np.concatenate(( log_proposal_density_np, utils.tensor2numpy(log_proposal_density) )) fig, axs = plt.subplots(1, 3, figsize=(7.5, 2.5)) axs[0].hist2d(train_dataset.data[:, 0], train_dataset.data[:, 1], range=bounds, bins=512, cmap=cm.viridis, rasterized=False) axs[0].set_xticks([]) axs[0].set_yticks([]) axs[1].pcolormesh(grid_dataset.X, grid_dataset.Y, np.exp(log_proposal_density_np).reshape(grid_dataset.X.shape)) axs[1].set_xlim(bounds[0])
def totensor(imgs, size=(256, 256)):
    """Convert a batch of NHWC images in [0, 255] to a normalized NCHW variable.

    :param imgs: float32 array of shape (N, H, W, C), presumably with values
        in [0, 255] (the caller below passes a decoded RGB image with a
        leading batch axis).
    :param size: Target (height, width) after resizing. Defaults to
        (256, 256), the previously hard-coded size.
    :return: Paddle variable of shape (N, C, size[0], size[1]) with values
        mapped from [0, 1] to [-1, 1].
    """
    imgs = fluid.dygraph.to_variable(imgs)
    imgs = imgs / 255.
    # NHWC -> NCHW
    imgs = fluid.layers.transpose(imgs, (0, 3, 1, 2))
    imgs = fluid.layers.image_resize(imgs, size)
    # [0, 1] -> [-1, 1]
    imgs = (imgs - 0.5) / 0.5
    return imgs


if __name__ == "__main__":
    # Translate every image in `real_source` with a pretrained A->B generator
    # and write the results to `fake_path` as 0000_fake.png, 0001_fake.png, ...
    gl._init()
    gl.set_value('rho', 0)
    # os.listdir order is arbitrary; sort so output numbering is reproducible.
    real_paths = sorted(os.listdir(real_source))
    with fluid.dygraph.guard():
        genA2B = ResnetGenerator(in_channels=3, out_channels=3, ngf=64,
                                 n_blocks=4)
        # Only the parameters are needed; the optimizer state is ignored.
        genA2B_para, _ = fluid.load_dygraph("Parameters/genA2B124.pdparams")
        genA2B.load_dict(genA2B_para)
        # Inference only: switch the generator to evaluation mode.
        genA2B.eval()
        for count, real_image_name in enumerate(real_paths):
            real_image_path = os.path.join(real_source, real_image_name)
            img = np.array(Image.open(real_image_path).convert("RGB")).astype(
                np.float32)
            img = img[np.newaxis, :, :, :]
            img = totensor(img)
            fakeA2B, _, _ = genA2B(img)
            a = (tensor2numpy(denorm(fakeA2B[0])) * 255).astype(np.uint8)
            a = Image.fromarray(a)
            save_path = os.path.join(fake_path, "%04d" % (count) + "_fake.png")
            a.save(save_path)
def _summarize(self, round_):
    """Update summary statistics for the latest round and log them to TensorBoard.

    Uses the most recent entries of ``self._parameter_bank`` and
    ``self._observation_bank``. Appends the MMD to ground-truth posterior
    samples (when available), the median distance of simulated observations
    to the true observation, and the negative KDE log-probability of the true
    parameters. Then logs a marginal histogram figure and the latest scalar
    summaries at ``global_step=round_ + 1`` and flushes the writer.

    :param round_: Zero-based index of the round just completed.
    """
    # Update summaries.
    # MMD against ground-truth posterior samples is best effort: such samples
    # may not be available for every simulator. Catch Exception (not a bare
    # except) so KeyboardInterrupt/SystemExit still propagate.
    try:
        mmd = utils.unbiased_mmd_squared(
            self._parameter_bank[-1],
            self._simulator.get_ground_truth_posterior_samples(
                num_samples=1000),
        )
        self._summary["mmds"].append(mmd.item())
    except Exception:
        pass

    median_observation_distance = torch.median(
        torch.sqrt(
            torch.sum(
                (self._observation_bank[-1] -
                 self._true_observation.reshape(1, -1))**2,
                dim=-1,
            )))
    self._summary["median-observation-distances"].append(
        median_observation_distance.item())

    negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
        samples=self._parameter_bank[-1],
        query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
    )
    self._summary["negative-log-probs-true-parameters"].append(
        negative_log_prob_true_parameters.item())

    global_step = round_ + 1

    # Plot most recently sampled parameters in TensorBoard.
    parameters = utils.tensor2numpy(self._parameter_bank[-1])
    figure = utils.plot_hist_marginals(
        data=parameters,
        ground_truth=utils.tensor2numpy(
            self._simulator.get_ground_truth_parameters()).reshape(-1),
        lims=self._simulator.parameter_plotting_limits,
    )
    self._summary_writer.add_figure(tag="posterior-samples",
                                    figure=figure,
                                    global_step=global_step)

    # (TensorBoard tag, summary key) pairs, logged in this order.
    scalar_summaries = (
        ("epochs-trained", "epochs"),
        ("best-validation-log-prob", "best-validation-log-probs"),
        ("median-observation-distance", "median-observation-distances"),
        ("negative-log-prob-true-parameters",
         "negative-log-probs-true-parameters"),
    )
    for tag, key in scalar_summaries:
        self._summary_writer.add_scalar(
            tag=tag,
            scalar_value=self._summary[key][-1],
            global_step=global_step,
        )

    if self._summary["mmds"]:
        self._summary_writer.add_scalar(
            tag="mmd",
            scalar_value=self._summary["mmds"][-1],
            global_step=global_step,
        )

    self._summary_writer.flush()
scalar_value=value, global_step=step) # load best val model path = os.path.join(cutils.get_checkpoint_root(), '{}-best-val-{}.t'.format(args.dataset_name, timestamp)) flow.load_state_dict(torch.load(path)) flow.eval() # calculate log-likelihood on test set with torch.no_grad(): log_likelihood = torch.Tensor([]) for batch in tqdm(test_loader): _, log_density = flow.log_prob(batch.to(device)) log_likelihood = torch.cat([log_likelihood, log_density]) path = os.path.join( log_dir, '{}-{}-log-likelihood.npy'.format(args.dataset_name, args.base_transform_type)) np.save(path, utils.tensor2numpy(log_likelihood)) mean_log_likelihood = log_likelihood.mean() std_log_likelihood = log_likelihood.std() # save log-likelihood s = 'Final score for {}: {:.2f} +- {:.2f}'.format( args.dataset_name.capitalize(), mean_log_likelihood.item(), 2 * std_log_likelihood.item() / np.sqrt(len(test_dataset))) print(s) filename = os.path.join(log_dir, 'test-results.txt') with open(filename, 'w') as file: file.write(s)
def train(self):
    """Run the U-GAT-IT training loop (PaddlePaddle dygraph).

    Each step draws one image from each training reader, performs one
    discriminator update followed by one generator update, applies
    ``self.Rho_clipper`` to both generators and prints the two losses.
    In addition:

    * every ``self.print_freq`` steps, A2B/B2A preview grids are written to
      ``<result_dir>/<dataset>/img``;
    * every ``self.save_freq`` steps, ``self.save`` is called on
      ``<result_dir>/<dataset>/model``;
    * every 1000 steps, all network and optimizer state dicts are dumped
      under ``<result_dir>/<dataset>/latest/new``.

    If ``self.resume`` is set, the checkpoint in
    ``<result_dir>/<dataset>/model`` with the highest step-number name is
    loaded first and training continues from the step after it.
    """
    self._set_train_mode(True)
    model_dir = os.path.join(self.result_dir, self.dataset, 'model')
    start_iter = self._restore_latest_checkpoint(model_dir)

    # training loop
    print('training start !')
    start_time = time.time()
    for step in range(start_iter, self.iteration + 1):
        batch_A = next(self.trainA_loader)
        batch_B = next(self.trainB_loader)
        real_A = self._batch_to_variable(batch_A)
        real_B = self._batch_to_variable(batch_B)

        Discriminator_loss = self._update_discriminators(real_A, real_B)
        Generator_loss = self._update_generators(real_A, real_B)

        self.Rho_clipper(self.genA2B)
        self.Rho_clipper(self.genB2A)

        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
              (step, self.iteration, time.time() - start_time,
               Discriminator_loss, Generator_loss))

        if step % self.print_freq == 0:
            self._write_previews(step)

        if step % self.save_freq == 0:
            self.save(model_dir, step)

        if step % 1000 == 0:
            self._save_latest()


def _set_train_mode(self, training):
    """Put both generators and all four discriminators in train or eval mode."""
    for net in (self.genA2B, self.genB2A, self.disGA, self.disGB,
                self.disLA, self.disLB):
        if training:
            net.train()
        else:
            net.eval()


def _restore_latest_checkpoint(self, model_dir):
    """Load the newest step-numbered checkpoint when resuming.

    Returns:
        The first step to run: 1 when not resuming or when no step-numbered
        checkpoint exists, otherwise the restored step + 1.
    """
    if not self.resume:
        return 1
    entries = os.listdir(model_dir)
    # Checkpoint entries are expected to be named by their step number.
    # Compare them as integers: a plain string sort ranks "9000" above
    # "10000" unless the names happen to be zero-padded.
    steps = [int(name) for name in entries if name.isdigit()]
    if not steps:
        if entries:
            print("[!] no step-numbered checkpoint in %s; starting from step 1"
                  % model_dir)
        return 1
    last_step = max(steps)
    print("[*]load %d" % (last_step))
    self.load(model_dir, last_step)
    print("[*] Load SUCCESS")
    # Continue after the restored step instead of restarting at 1, so the
    # step counter, output file names and iteration budget stay consistent.
    return last_step + 1


def _batch_to_variable(self, batch):
    """Wrap the first sample of a reader batch as a (1, 3, 256, 256) float32 variable."""
    # NOTE(review): 256x256 is hard-coded although ``self.img_size`` exists;
    # confirm the readers always emit 256x256 images.
    image = np.array([batch[0].reshape(3, 256, 256)]).astype("float32")
    return to_variable(image)


def _lsgan(self, logit, is_real):
    """Least-squares GAN term: MSE of ``logit`` against all-ones (real) or all-zeros (fake)."""
    target = ones_like(logit) if is_real else zeros_like(logit)
    return self.MSE_loss(logit, target)


def _clear_gradients(self, optimizer):
    """Clear gradients on every network and on ``optimizer``.

    Both update steps back-propagate through generators *and*
    discriminators, so every network is cleared after each step.
    """
    for net in (self.genB2A, self.genA2B, self.disGA, self.disLA,
                self.disGB, self.disLB):
        net.clear_gradients()
    optimizer.clear_gradients()


def _update_discriminators(self, real_A, real_B):
    """Perform one discriminator update; return the total discriminator loss."""
    fake_A2B, _, _ = self.genA2B(real_A)
    fake_B2A, _, _ = self.genB2A(real_B)

    real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
    real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
    real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
    real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

    fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
    fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
    fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
    fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

    def d_term(real_logit, fake_logit):
        # Real images are pushed towards 1, translated images towards 0.
        return self._lsgan(real_logit, True) + self._lsgan(fake_logit, False)

    D_loss_A = self.adv_weight * (
        d_term(real_GA_logit, fake_GA_logit)
        + d_term(real_GA_cam_logit, fake_GA_cam_logit)
        + d_term(real_LA_logit, fake_LA_logit)
        + d_term(real_LA_cam_logit, fake_LA_cam_logit))
    D_loss_B = self.adv_weight * (
        d_term(real_GB_logit, fake_GB_logit)
        + d_term(real_GB_cam_logit, fake_GB_cam_logit)
        + d_term(real_LB_logit, fake_LB_logit)
        + d_term(real_LB_cam_logit, fake_LB_cam_logit))

    Discriminator_loss = D_loss_A + D_loss_B
    Discriminator_loss.backward()
    self.D_optim.minimize(Discriminator_loss)
    self._clear_gradients(self.D_optim)
    return Discriminator_loss


def _update_generators(self, real_A, real_B):
    """Perform one generator update; return the total generator loss."""
    fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
    fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

    fake_A2B2A, _, _ = self.genB2A(fake_A2B)
    fake_B2A2B, _, _ = self.genA2B(fake_B2A)

    fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
    fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

    fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
    fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
    fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
    fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

    # Adversarial: the generators try to make every discriminator say "real".
    G_ad_A = (self._lsgan(fake_GA_logit, True)
              + self._lsgan(fake_GA_cam_logit, True)
              + self._lsgan(fake_LA_logit, True)
              + self._lsgan(fake_LA_cam_logit, True))
    G_ad_B = (self._lsgan(fake_GB_logit, True)
              + self._lsgan(fake_GB_cam_logit, True)
              + self._lsgan(fake_LB_logit, True)
              + self._lsgan(fake_LB_cam_logit, True))

    # Cycle reconstruction (A->B->A, B->A->B) and identity mapping
    # (A through genB2A, B through genA2B) must reproduce the input.
    G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
    G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)
    G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
    G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

    # CAM: each generator's CAM logit is labelled 1 on images from the
    # other domain and 0 on images from its own output domain.
    G_cam_loss_A = self.BCE_loss(
        fake_B2A_cam_logit, ones_like(fake_B2A_cam_logit)) + self.BCE_loss(
            fake_A2A_cam_logit, zeros_like(fake_A2A_cam_logit))
    G_cam_loss_B = self.BCE_loss(
        fake_A2B_cam_logit, ones_like(fake_A2B_cam_logit)) + self.BCE_loss(
            fake_B2B_cam_logit, zeros_like(fake_B2B_cam_logit))

    G_loss_A = (self.adv_weight * G_ad_A
                + self.cycle_weight * G_recon_loss_A
                + self.identity_weight * G_identity_loss_A
                + self.cam_weight * G_cam_loss_A)
    G_loss_B = (self.adv_weight * G_ad_B
                + self.cycle_weight * G_recon_loss_B
                + self.identity_weight * G_identity_loss_B
                + self.cam_weight * G_cam_loss_B)

    Generator_loss = G_loss_A + G_loss_B
    Generator_loss.backward()
    self.G_optim.minimize(Generator_loss)
    self._clear_gradients(self.G_optim)
    return Generator_loss


def _append_preview_columns(self, A2B, B2A, batch_A, batch_B):
    """Append one 7-tile column per direction for a single A/B image pair.

    Each column stacks, top to bottom: input, identity heatmap, identity
    output, translation heatmap, translation, cycle heatmap, cycle output.
    Returns the widened ``(A2B, B2A)`` grids.
    """
    real_A = self._batch_to_variable(batch_A)
    real_B = self._batch_to_variable(batch_B)

    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

    def image(x):
        return RGB2BGR(tensor2numpy(denorm(x[0])))

    def heatmap(x):
        return cam(tensor2numpy(x[0]), self.img_size)

    column_A2B = np.concatenate(
        (image(real_A), heatmap(fake_A2A_heatmap), image(fake_A2A),
         heatmap(fake_A2B_heatmap), image(fake_A2B),
         heatmap(fake_A2B2A_heatmap), image(fake_A2B2A)), 0)
    column_B2A = np.concatenate(
        (image(real_B), heatmap(fake_B2B_heatmap), image(fake_B2B),
         heatmap(fake_B2A_heatmap), image(fake_B2A),
         heatmap(fake_B2A2B_heatmap), image(fake_B2A2B)), 0)
    return (np.concatenate((A2B, column_A2B), 1),
            np.concatenate((B2A, column_B2A), 1))


def _write_previews(self, step):
    """Write ``A2B_<step>.png`` and ``B2A_<step>.png`` to ``<result_dir>/<dataset>/img``.

    Each grid holds 5 columns from the training readers followed by 5 from
    the test readers. Networks are in eval mode while sampling and are put
    back into train mode afterwards.
    """
    train_sample_num = 5
    test_sample_num = 5
    A2B = np.zeros((self.img_size * 7, 0, 3))
    B2A = np.zeros((self.img_size * 7, 0, 3))

    self._set_train_mode(False)

    for _ in range(train_sample_num):
        batch_A = next(self.trainA_loader)
        batch_B = next(self.trainB_loader)
        A2B, B2A = self._append_preview_columns(A2B, B2A, batch_A, batch_B)

    for _ in range(test_sample_num):
        # NOTE(review): unlike the training readers, the test readers are
        # *called* here, creating a fresh generator per draw; confirm this is
        # intended (it may yield the same leading sample every time).
        batch_A = next(self.testA_loader())
        batch_B = next(self.testB_loader())
        A2B, B2A = self._append_preview_columns(A2B, B2A, batch_A, batch_B)

    img_dir = os.path.join(self.result_dir, self.dataset, 'img')
    cv2.imwrite(os.path.join(img_dir, 'A2B_%07d.png' % step), A2B * 255.0)
    cv2.imwrite(os.path.join(img_dir, 'B2A_%07d.png' % step), B2A * 255.0)

    self._set_train_mode(True)


def _save_latest(self):
    """Dump all network and optimizer state dicts under ``<result_dir>/<dataset>/latest/new``."""
    targets = (
        (self.genA2B, "genA2B"),
        (self.genB2A, "genB2A"),
        (self.disGA, "disGA"),
        (self.disGB, "disGB"),
        (self.disLA, "disLA"),
        (self.disLB, "disLB"),
        (self.D_optim, "D_optim"),
        (self.G_optim, "G_optim"),
        # NOTE(review): save_dygraph appears to choose the file suffix from
        # the kind of state (.pdparams for layers, .pdopt for optimizers),
        # so these two writes add D_optim.pdparams / G_optim.pdparams next
        # to the optimizer files rather than overwriting them. This is
        # presumably so load_dygraph(".../D_optim"), which needs a .pdparams
        # file, can also return the optimizer state. Confirm before changing.
        (self.genA2B, "D_optim"),
        (self.genB2A, "G_optim"),
    )
    for owner, name in targets:
        fluid.save_dygraph(
            owner.state_dict(),
            os.path.join(self.result_dir,
                         self.dataset + "/latest/new/" + name))