class fAnoGAN:
    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        critic_iters=1,
        gen_iters=5,
        n_epochs=10,
        gp_lambda=10,
        z_dim=512,
        print_every_iter=20,
        plot_every_epoch=1,
        log_dir=None,
        load_path=None,
        logger="visdom",
        data_dir=None,
        use_encoder=True,
        encoder_feature_weight=1e-4,
        encoder_discr_weight=0.0,
    ):
        self.plot_every_epoch = plot_every_epoch
        self.print_every_iter = print_every_iter
        self.gp_lambda = gp_lambda
        self.n_epochs = n_epochs
        self.gen_iters = gen_iters
        self.critic_iters = critic_iters
        self.size = input_shape[2]
        self.batch_size = input_shape[0]
        self.input_shape = input_shape
        self.z_dim = z_dim
        self.logger = logger
        self.data_dir = data_dir
        self.use_encoder = use_encoder
        self.encoder_feature_weight = encoder_feature_weight
        self.encoder_discr_weight = encoder_discr_weight

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="fanogan",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.n_image_channels = input_shape[1]

        self.gen = IWGenerator(self.size, z_dim=z_dim, n_image_channels=self.n_image_channels)
        self.dis = IWDiscriminator(self.size, n_image_channels=self.n_image_channels)

        self.gen.apply(weights_init)
        self.dis.apply(weights_init)

        self.optimizer_G = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.dis.parameters(), lr=lr, betas=(0.5, 0.999))

        self.gen = self.gen.to(self.device)
        self.dis = self.dis.to(self.device)

        if self.use_encoder:
            self.enc = IWEncoder(self.size, z_dim=z_dim, n_image_channels=self.n_image_channels)
            self.enc.apply(weights_init)
            self.enc = self.enc.to(self.device)
            self.optimizer_E = torch.optim.Adam(self.enc.parameters(), lr=lr, betas=(0.5, 0.999))

        self.z = torch.randn(self.batch_size, z_dim).to(self.device)

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(self.dis, os.path.join(load_path, "dis_final.pth"))
            PytorchExperimentLogger.load_model_static(self.gen, os.path.join(load_path, "gen_final.pth"))
            if self.use_encoder:
                try:
                    PytorchExperimentLogger.load_model_static(self.enc, os.path.join(load_path, "enc_final.pth"))
                except Exception:
                    warnings.warn("Could not find an Encoder in the directory")
                    time.sleep(5)

    def train(self):
        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.size,
            slice_offset=10,
        )

        print("Training GAN...")
        for epoch in range(self.n_epochs):
            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                # Rescale to [-1, 1] and add a little noise to stabilize training.
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01
                real_imgs = batch.to(self.device)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                if i % self.critic_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = real_imgs.shape[0]
                    self.z.normal_()

                    fake_imgs = self.gen(self.z[:batch_size_curr])

                    real_validity = self.dis(real_imgs)
                    fake_validity = self.dis(fake_imgs)

                    gradient_penalty = self.calc_gradient_penalty(
                        self.dis,
                        real_imgs,
                        fake_imgs,
                        batch_size_curr,
                        self.size,
                        self.device,
                        self.gp_lambda,
                        n_image_channels=self.n_image_channels,
                    )
                    # gp_lambda is already applied inside calc_gradient_penalty.
                    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty
                    d_loss.backward()
                    self.optimizer_D.step()

                    w_dist = (-torch.mean(real_validity) + torch.mean(fake_validity)).item()

                # -----------------
                #  Train Generator
                # -----------------
                if i % self.gen_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = self.batch_size
                    fake_imgs = self.gen(self.z)
                    fake_validity = self.dis(fake_imgs)

                    g_loss = -torch.mean(fake_validity)
                    g_loss.backward()
                    self.optimizer_G.step()

                if i % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{i}/{len(train_loader)} "
                        f" ({100.0 * i / len(train_loader):.0f}%)] Dis: "
                        f"{d_loss.item() / batch_size_curr:.6f} vs Gen: "
                        f"{g_loss.item() / batch_size_curr:.6f} (W-Dist: {w_dist / batch_size_curr:.6f})"
                    )
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + i
                    self.tx.add_result(d_loss.item(), name="trainDisCost", tag="DisVsGen", counter=cnt)
                    self.tx.add_result(g_loss.item(), name="trainGenCost", tag="DisVsGen", counter=cnt)
                    self.tx.add_result(w_dist, "wasserstein_distance", counter=cnt)

                    self.tx.l[0].show_image_grid(
                        fake_imgs.reshape(batch_size_curr, self.n_image_channels, self.size, self.size),
                        "GeneratedImages",
                        image_args={"normalize": True},
                    )

            self.tx.save_model(self.dis, "dis_final")
            self.tx.save_model(self.gen, "gen_final")

        self.gen.train(True)
        self.dis.train(True)

        if not self.use_encoder:
            time.sleep(10)
            return

        # Train the encoder to invert the generator, matching images and
        # discriminator features of the reconstructions.
        weight_features = self.encoder_feature_weight
        weight_disc = self.encoder_discr_weight

        print("Training Encoder...")
        for epoch in range(self.n_epochs // 2):
            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01
                real_img = batch.to(self.device)
                batch_size_curr = real_img.shape[0]

                self.optimizer_G.zero_grad()
                self.optimizer_D.zero_grad()
                self.optimizer_E.zero_grad()

                z = self.enc(real_img)
                recon_img = self.gen(z)

                _, img_feats = self.dis.forward_last_feature(real_img)
                disc_loss, recon_feats = self.dis.forward_last_feature(recon_img)

                recon_img = recon_img.reshape(batch_size_curr, self.n_image_channels, self.size, self.size)

                loss_img = self.mse(real_img, recon_img)
                loss_feat = self.mse(img_feats, recon_feats) * weight_features
                disc_loss = -torch.mean(disc_loss) * weight_disc

                loss = loss_img + loss_feat + disc_loss
                loss.backward()
                self.optimizer_E.step()

                if i % self.print_every_iter == 0:
                    status_str = f"[Epoch {epoch}/{self.n_epochs // 2}] [Batch {i}/{len(train_loader)}] Loss: {loss.item():.6f}"
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + i
                    self.tx.add_result(loss.item(), name="EncoderLoss", counter=cnt)

                    self.tx.l[0].show_image_grid(
                        real_img.reshape(batch_size_curr, self.n_image_channels, self.size, self.size),
                        "RealImages",
                        image_args={"normalize": True},
                    )
                    self.tx.l[0].show_image_grid(
                        recon_img.reshape(batch_size_curr, self.n_image_channels, self.size, self.size),
                        "ReconImages",
                        image_args={"normalize": True},
                    )

            self.tx.save_model(self.enc, "enc_final")

        self.enc.train(False)
        time.sleep(10)

    def score_sample(self, np_array):
        """Return a single anomaly score (max over slice scores) for a volume."""
        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) * self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1
            real_imgs = batch.to(self.device)
            batch_size_curr = real_imgs.shape[0]

            if self.use_encoder:
                z = self.enc(real_imgs)
            else:
                z = self.backprop_to_nearest_z(real_imgs)

            pseudo_img_recon = self.gen(z)
            pseudo_img_recon = pseudo_img_recon.reshape(batch_size_curr, self.n_image_channels, self.size, self.size)

            img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs), dim=1, keepdim=True)
            loss = torch.sum(img_diff, dim=(1, 2, 3)).detach()
            slice_scores += loss.cpu().tolist()

        return np.max(slice_scores)

    def score_pixels(self, np_array):
        """Return a per-pixel anomaly map with the same shape as the input volume."""
        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) * self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1
            real_imgs = batch.to(self.device)
            batch_size_curr = real_imgs.shape[0]

            if self.use_encoder:
                z = self.enc(real_imgs)
            else:
                z = self.backprop_to_nearest_z(real_imgs)

            pseudo_img_recon = self.gen(z)
            pseudo_img_recon = pseudo_img_recon.reshape(batch_size_curr, self.n_image_channels, self.size, self.size)

            img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs), dim=1, keepdim=True)
            loss = img_diff[:, 0, :]
            target_tensor[i * self.batch_size:(i + 1) * self.batch_size] = loss.cpu()

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    def backprop_to_nearest_z(self, real_imgs):
        """Find a latent z whose generated image matches real_imgs via gradient descent on z."""
        batch_size_curr = real_imgs.shape[0]

        z = torch.randn(batch_size_curr, self.z_dim, device=self.device)
        z.requires_grad = True

        # optimizer_z = torch.optim.LBFGS([z], lr=0.02)
        optimizer_z = torch.optim.Adam([z], lr=0.002)
        # optimizer_z = torch.optim.RMSprop([z], lr=0.05)

        for i in range(200):

            def closure():
                self.gen.zero_grad()
                optimizer_z.zero_grad()

                pseudo_img_recon = self.gen(z)

                _, img_feats = self.dis.forward_last_feature(real_imgs)
                disc_loss, recon_feats = self.dis.forward_last_feature(pseudo_img_recon)

                pseudo_img_recon = pseudo_img_recon.reshape(
                    batch_size_curr, self.n_image_channels, self.size, self.size
                )
                disc_loss = torch.mean(disc_loss)

                imgs_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs))
                feats_diff = torch.mean(torch.abs(img_feats - recon_feats))
                loss = imgs_diff - disc_loss * 0.001  # + feats_diff

                loss.backward()
                return loss

            optimizer_z.step(closure)

        return z.detach()

    def score(self, batch):
        real_imgs = batch.to(self.device).float()
        batch_size_curr = real_imgs.shape[0]

        # For iterative refinement of z instead of the encoder, see backprop_to_nearest_z.
        z = self.enc(real_imgs)

        pseudo_img_recon = self.gen(z)
        pseudo_img_recon = pseudo_img_recon.reshape(batch_size_curr, self.n_image_channels, self.size, self.size)

        img_diff = torch.mean(torch.abs(pseudo_img_recon - real_imgs), dim=1, keepdim=True)
        img_scores = torch.sum(img_diff, dim=(1, 2, 3)).detach().tolist()
        pixel_scores = img_diff.flatten().detach().tolist()

        self.tx.vlog.show_image_grid(pseudo_img_recon, "PseudoImages", image_args={"normalize": True})
        self.tx.vlog.show_image_grid(
            torch.mean(torch.abs(pseudo_img_recon - real_imgs), dim=1, keepdim=True),
            "DiffImages",
            image_args={"normalize": True},
        )

        return img_scores, pixel_scores

    @staticmethod
    def mse(x, y):
        return torch.mean(torch.pow(x - y, 2))

    @staticmethod
    def calc_gradient_penalty(netD, real_data, fake_data, batch_size, dim, device, gp_lambda, n_image_channels=3):
        """WGAN-GP gradient penalty, evaluated on random interpolates of real and fake data."""
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand(batch_size, int(real_data.nelement() / batch_size)).contiguous()
        alpha = alpha.view(batch_size, n_image_channels, dim, dim)
        alpha = alpha.to(device)

        fake_data = fake_data.view(batch_size, n_image_channels, dim, dim)
        interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
        interpolates = interpolates.to(device)
        interpolates.requires_grad_(True)

        disc_interpolates = netD(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).to(device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
        return gradient_penalty

    def print(self, *args):
        print(*args)
        self.tx.print(*args)

    def log_result(self, val, key=None):
        self.tx.print(key, val)
        self.tx.add_result_without_epoch(val, key)
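# Usage sketch (illustrative only): a minimal sketch of driving fAnoGAN end to
# end. The input shape, directory paths, and the random test volume below are
# assumptions for illustration, not values prescribed by this module.
def _example_fanogan_usage():
    import numpy as np

    model = fAnoGAN(
        input_shape=(16, 1, 128, 128),      # (batch, channels, height, width)
        data_dir="/path/to/train_slices",   # hypothetical training-data directory
        log_dir="/path/to/logs",            # hypothetical log directory
        use_encoder=True,
    )
    model.train()

    # Score a (slices, height, width) volume in [0, 1]; random data stands in
    # for a real scan here.
    volume = np.random.rand(64, 256, 256).astype(np.float32)
    sample_score = model.score_sample(volume)  # scalar: max over slice-level scores
    pixel_map = model.score_pixels(volume)     # per-pixel anomaly map at input resolution
    return sample_score, pixel_map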
class ceVAE:
    @monkey_patch_fn_args_as_config
    def __init__(
        self,
        input_shape,
        lr=1e-4,
        n_epochs=20,
        z_dim=512,
        model_feature_map_sizes=(16, 64, 256, 1024),
        use_geco=False,
        beta=0.01,
        ce_factor=0.5,
        score_mode="combi",
        load_path=None,
        log_dir=None,
        logger="visdom",
        print_every_iter=100,
        data_dir=None,
    ):
        self.score_mode = score_mode
        self.ce_factor = ce_factor
        self.beta = beta
        self.print_every_iter = print_every_iter
        self.n_epochs = n_epochs
        self.batch_size = input_shape[0]
        self.z_dim = z_dim
        self.use_geco = use_geco
        self.input_shape = input_shape
        self.logger = logger
        self.data_dir = data_dir

        log_dict = {}
        if logger is not None:
            log_dict = {0: logger}
        self.tx = PytorchExperimentStub(
            name="cevae",
            base_dir=log_dir,
            config=fn_args_as_config,
            loggers=log_dict,
        )

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.model = VAE(input_size=input_shape[1:], z_dim=z_dim, fmap_sizes=model_feature_map_sizes).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.vae_loss_ema = 1
        self.theta = 1

        if load_path is not None:
            PytorchExperimentLogger.load_model_static(self.model, os.path.join(load_path, "vae_final.pth"))
            time.sleep(5)

    def train(self):
        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):
            self.model.train()
            train_loss = 0
            print("Start epoch")
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                data = data * 2 - 1  # rescale to [-1, 1]
                self.optimizer.zero_grad()
                inpt = data.to(self.device)

                ### VAE Part
                loss_vae = 0
                if self.ce_factor < 1:
                    x_rec_vae, z_dist = self.model(inpt)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss_vae = self.rec_loss_fn(x_rec_vae, inpt)
                    loss_vae = kl_loss + rec_loss_vae * self.theta

                ### CE Part
                loss_ce = 0
                if self.ce_factor > 0:
                    # Mask out random squares and train the model to inpaint
                    # them (context encoding).
                    ce_tensor = get_square_mask(
                        data.shape,
                        square_size=(0, np.max(self.input_shape[2:]) // 2),
                        noise_val=(torch.min(data).item(), torch.max(data).item()),
                        n_squares=(0, 3),
                    )
                    ce_tensor = torch.from_numpy(ce_tensor).float()
                    inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, data)
                    inpt_noisy = inpt_noisy.to(self.device)

                    x_rec_ce, _ = self.model(inpt_noisy)
                    rec_loss_ce = self.rec_loss_fn(x_rec_ce, inpt)
                    loss_ce = rec_loss_ce

                loss = (1.0 - self.ce_factor) * loss_vae + self.ce_factor * loss_ce

                if self.use_geco and self.ce_factor < 1:
                    g_goal = 0.1
                    g_lr = 1e-4
                    self.vae_loss_ema = (1.0 - 0.9) * rec_loss_vae + 0.9 * self.vae_loss_ema
                    self.theta = self.geco_beta_update(self.theta, self.vae_loss_ema, g_goal, g_lr, speedup=2)

                if torch.isnan(loss):
                    print("A wild NaN occurred")
                    continue

                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item() / len(inpt):.6f}"
                    )
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx
                    if self.ce_factor < 1:
                        self.tx.l[0].show_image_grid(inpt, name="Input-VAE", image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(x_rec_vae, name="Output-VAE", image_args={"normalize": True})
                        if self.beta > 0:
                            self.tx.add_result(torch.mean(kl_loss).item(), name="Kl-loss", tag="Losses", counter=cnt)
                        self.tx.add_result(torch.mean(rec_loss_vae).item(), name="Rec-loss", tag="Losses", counter=cnt)
                        self.tx.add_result(loss_vae.item(), name="Train-loss", tag="Losses", counter=cnt)
                    if self.ce_factor > 0:
                        self.tx.l[0].show_image_grid(inpt_noisy, name="Input-CE", image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(x_rec_ce, name="Output-CE", image_args={"normalize": True})

            print(f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}")

            self.model.eval()
            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                for i, data in data_loader_:
                    data = data * 2 - 1
                    inpt = data.to(self.device)
                    x_rec, z_dist = self.model(inpt, sample=False)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss = self.rec_loss_fn(x_rec, inpt)
                    loss = kl_loss + rec_loss * self.theta

                    val_loss += loss.item()

            self.tx.add_result(
                val_loss / len(val_loader), name="Val-Loss", tag="Losses", counter=(epoch + 1) * len(train_loader)
            )
            print(f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}")

        self.tx.save_model(self.model, "vae_final")
        time.sleep(10)

    def score_sample(self, np_array):
        """Return a single anomaly score (max over slice scores) for a volume."""
        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        slice_scores = []

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) * self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1

            with torch.no_grad():
                inpt = batch.to(self.device).float()
                x_rec, z_dist = self.model(inpt, sample=False)
                kl_loss = self.kl_loss_fn(z_dist, sum_samples=False)
                rec_loss = self.rec_loss_fn(x_rec, inpt, sum_samples=False)
                img_scores = kl_loss * self.beta + rec_loss * self.theta

            slice_scores += img_scores.cpu().tolist()

        return np.max(slice_scores)

    def score_pixels(self, np_array):
        """Return a per-pixel anomaly map with the same shape as the input volume."""
        orig_shape = np_array.shape
        to_transforms = torch.nn.Upsample((self.input_shape[2], self.input_shape[3]), mode="bilinear")
        from_transforms = torch.nn.Upsample((orig_shape[1], orig_shape[2]), mode="bilinear")

        data_tensor = torch.from_numpy(np_array).float()
        data_tensor = to_transforms(data_tensor[None])[0]
        target_tensor = torch.zeros_like(data_tensor)

        for i in range(ceil(orig_shape[0] / self.batch_size)):
            batch = data_tensor[i * self.batch_size:(i + 1) * self.batch_size].unsqueeze(1)
            batch = batch * 2 - 1
            inpt = batch.to(self.device).float()

            x_rec, z_dist = self.model(inpt, sample=False)

            if self.score_mode == "combi":
                rec = torch.pow((x_rec - inpt), 2).detach().cpu()
                rec = torch.mean(rec, dim=1, keepdim=True)

                def __err_fn(x):
                    x_r, z_d = self.model(x, sample=False)
                    loss = self.kl_loss_fn(z_d)
                    return loss

                loss_grad_kl = (
                    get_smooth_image_gradient(
                        model=self.model, inpt=inpt, err_fn=__err_fn, grad_type="vanilla", n_runs=2
                    )
                    .detach()
                    .cpu()
                )
                loss_grad_kl = torch.mean(loss_grad_kl, dim=1, keepdim=True)
                pixel_scores = smooth_tensor(normalize(loss_grad_kl), kernel_size=8) * rec

            elif self.score_mode == "rec":
                rec = torch.pow((x_rec - inpt), 2).detach().cpu()
                rec = torch.mean(rec, dim=1, keepdim=True)
                pixel_scores = rec

            elif self.score_mode == "grad":

                def __err_fn(x):
                    x_r, z_d = self.model(x, sample=False)
                    kl_loss_ = self.kl_loss_fn(z_d)
                    rec_loss_ = self.rec_loss_fn(x_r, x)
                    loss_ = kl_loss_ * self.beta + rec_loss_ * self.theta
                    return torch.mean(loss_)

                loss_grad_kl = (
                    get_smooth_image_gradient(
                        model=self.model, inpt=inpt, err_fn=__err_fn, grad_type="vanilla", n_runs=2
                    )
                    .detach()
                    .cpu()
                )
                loss_grad_kl = torch.mean(loss_grad_kl, dim=1, keepdim=True)
                pixel_scores = smooth_tensor(normalize(loss_grad_kl), kernel_size=8)

            self.tx.elog.show_image_grid(inpt, name="Input", image_args={"normalize": True}, n_iter=i)
            self.tx.elog.show_image_grid(x_rec, name="Output", image_args={"normalize": True}, n_iter=i)
            self.tx.elog.show_image_grid(pixel_scores, name="Scores", image_args={"normalize": True}, n_iter=i)

            target_tensor[i * self.batch_size:(i + 1) * self.batch_size] = pixel_scores.detach().cpu()[:, 0, :]

        target_tensor = from_transforms(target_tensor[None])[0]

        return target_tensor.detach().numpy()

    @staticmethod
    def load_trained_model(model, tx, path):
        tx.elog.load_model_static(model=model, model_file=path)

    @staticmethod
    def kl_loss_fn(z_post, sum_samples=True, correct=False):
        z_prior = dist.Normal(0, 1.0)
        kl_div = dist.kl_divergence(z_post, z_prior)
        if correct:
            kl_div = torch.sum(kl_div, dim=(1, 2, 3))
        else:
            kl_div = torch.mean(kl_div, dim=(1, 2, 3))
        if sum_samples:
            return torch.mean(kl_div)
        else:
            return kl_div

    @staticmethod
    def rec_loss_fn(recon_x, x, sum_samples=True, correct=False):
        if correct:
            x_dist = dist.Laplace(recon_x, 1.0)
            log_p_x_z = x_dist.log_prob(x)
            log_p_x_z = torch.sum(log_p_x_z, dim=(1, 2, 3))
        else:
            log_p_x_z = -torch.abs(recon_x - x)
            log_p_x_z = torch.mean(log_p_x_z, dim=(1, 2, 3))
        if sum_samples:
            return -torch.mean(log_p_x_z)
        else:
            return -log_p_x_z

    @staticmethod
    def get_inpt_grad(model, inpt, err_fn):
        model.zero_grad()
        inpt = inpt.detach()
        inpt.requires_grad = True

        err = err_fn(inpt)
        err.backward()

        grad = inpt.grad.detach()
        model.zero_grad()

        return torch.abs(grad.detach())

    @staticmethod
    def geco_beta_update(beta, error_ema, goal, step_size, min_clamp=1e-10, max_clamp=1e4, speedup=None):
        constraint = (error_ema - goal).detach()
        if speedup is not None and constraint > 0.0:
            beta = beta * torch.exp(speedup * step_size * constraint)
        else:
            beta = beta * torch.exp(step_size * constraint)
        if min_clamp is not None:
            beta = np.max((beta.item(), min_clamp))
        if max_clamp is not None:
            beta = np.min((beta.item(), max_clamp))
        return beta

    @staticmethod
    def get_ema(new, old, alpha):
        if old is None:
            return new
        return (1.0 - alpha) * new + alpha * old

    def print(self, *args):
        print(*args)
        self.tx.print(*args)

    def log_result(self, val, key=None):
        self.tx.print(key, val)
        self.tx.add_result_without_epoch(val, key)
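# Usage sketch (illustrative only): the same minimal flow for ceVAE. As above,
# the shape, paths, and the random test volume are assumptions for
# illustration, not values prescribed by this module.
def _example_cevae_usage():
    import numpy as np

    model = ceVAE(
        input_shape=(16, 1, 128, 128),      # (batch, channels, height, width)
        data_dir="/path/to/train_slices",   # hypothetical training-data directory
        log_dir="/path/to/logs",            # hypothetical log directory
        score_mode="combi",                 # "rec", "grad", or "combi"
    )
    model.train()

    volume = np.random.rand(64, 256, 256).astype(np.float32)
    sample_score = model.score_sample(volume)  # scalar: max over slice-level scores
    pixel_map = model.score_pixels(volume)     # per-pixel anomaly map at input resolution
    return sample_score, pixel_map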