import numpy as np
import torch
import torch.nn.functional as F


def max_z_for_gen(args, virtual_gen, prev_gen, prev_cls, z_mu, z_var, z_t,
                  gradient_steps=10, budget=10):
    """
    retrieve the latent codes of the most interfered samples for the vae branch
    :param args: experiment configuration
    :param virtual_gen: generator after the virtual (one-step-ahead) update
    :param prev_gen: generator before the update
    :param prev_cls: classifier before the update
    :param z_mu: posterior means used to initialise the latent search
    :param z_var: posterior variances used to initialise the latent search
    :param z_t: unused, kept for interface compatibility
    :param gradient_steps: number of gradient-ascent steps per retrieved batch
    :param budget: number of batches of latent codes to retrieve
    :return: the retrieved latent codes
    """
    z_new_max = None
    for i in range(budget):
        with torch.no_grad():
            if args.mir_init_prior:
                # initialise the search from the prior
                z_new = prev_gen.prior.sample((z_mu.shape[0], )).to(args.device)
            else:
                # initialise the search from the approximate posterior
                z_new = prev_gen.reparameterize(z_mu, z_var)

        for j in range(gradient_steps):
            z_new.requires_grad = True
            x_new = prev_gen.decode(z_new)

            prev_x_mean, prev_z_mu, prev_z_var, prev_ldj, prev_z0, prev_zk = \
                prev_gen(x_new)
            _, prev_rec, prev_kl, _ = calculate_loss(prev_x_mean, x_new, prev_z_mu,
                    prev_z_var, prev_z0, prev_zk, prev_ldj, args, beta=1)

            virtual_x_mean, virtual_z_mu, virtual_z_var, virtual_ldj, virtual_z0, virtual_zk = \
                virtual_gen(x_new)
            _, virtual_rec, virtual_kl, _ = calculate_loss(virtual_x_mean, x_new, virtual_z_mu,
                    virtual_z_var, virtual_z0, virtual_zk, virtual_ldj, args, beta=1)

            # TODO(warning, KL can explode)
            # maximise the interference
            KL = 0
            if args.gen_kl_coeff > 0.:
                KL = virtual_kl - prev_kl

            REC = 0
            if args.gen_rec_coeff > 0.:
                REC = virtual_rec - prev_rec

            # the predictions from the two models should be confident
            ENT = 0
            if args.gen_ent_coeff > 0.:
                y_pre = prev_cls(x_new)
                ENT = cross_entropy(y_pre, y_pre)
                # TODO(should we do the args.curr_entropy thing?)

            # the newly found samples should be different from each other
            DIV = 0
            if args.gen_div_coeff > 0.:
                for found_z_i in range(i):
                    DIV += F.mse_loss(
                        z_new,
                        z_new_max[found_z_i * z_new.size(0):
                                  found_z_i * z_new.size(0) + z_new.size(0)]) / i

            # (NEW) stay on the gaussian shell loss:
            SHELL = 0
            if args.gen_shell_coeff > 0.:
                SHELL = mse(
                    torch.norm(z_new, 2, dim=1),
                    torch.ones_like(torch.norm(z_new, 2, dim=1)) * np.sqrt(args.z_size))

            gain = args.gen_kl_coeff * KL + \
                   args.gen_rec_coeff * REC + \
                   -args.gen_ent_coeff * ENT + \
                   args.gen_div_coeff * DIV + \
                   -args.gen_shell_coeff * SHELL

            # gradient ascent on the gain w.r.t. the latent code
            z_g = torch.autograd.grad(gain, z_new)[0]
            z_new = (z_new + 1 * z_g).detach()

        if z_new_max is None:
            z_new_max = z_new.clone()
        else:
            z_new_max = torch.cat([z_new_max, z_new.clone()])

    z_new_max.requires_grad = False
    return z_new_max
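# max_z_for_gen above (and retrieve_gen_for_gen further below) rely on
# module-level helpers (calculate_loss, mse, cross_entropy) defined elsewhere
# in the repo. As a rough guide to how the ENT and SHELL terms behave, here is
# a minimal sketch of hypothetical stand-ins for mse and cross_entropy --
# assumptions about their behaviour, not the repo's actual implementations.
# Reading cross_entropy(y_pre, y_pre) as the entropy of the classifier's own
# softmax prediction means that minimising ENT pushes the decoded samples
# toward regions where the old classifier is confident.

# hypothetical stand-ins, for illustration only
mse = torch.nn.MSELoss()  # SHELL term: MSE between ||z|| and sqrt(z_size)


def cross_entropy(logits, target_logits):
    """soft cross-entropy between two sets of logits; with
    target_logits == logits this is the entropy of the prediction,
    i.e. a confidence penalty for the ENT term above"""
    target = F.softmax(target_logits, dim=1)
    return -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()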
# linear KL warm-up for the generator's VAE objective, capped at max_beta
args.beta = min(sample_amt / max(args.warmup, 1.), args.max_beta)

#------ Train Generator ------#
#-------------------------------
# Begin Generator Iteration Loop
for it in range(args.gen_iters):
    x_mean, z_mu, z_var, ldj, z0, zk = gen(data)
    gen_loss, rec, kl, _ = calculate_loss(x_mean, data, z_mu, z_var, z0, zk,
                                          ldj, args, beta=args.beta)
    tot_gen_loss = 0 + gen_loss

    if task > 0 and args.gen_method != 'no_rehearsal':
        if it == 0 or not args.reuse_samples:
            if args.gen_method == 'rand_gen':
                # plain generative replay: sample memories from the previous generator
                mem_x = prev_gen.generate(args.batch_size * args.n_mem).detach()
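# For concreteness, the warm-up above ramps beta linearly with the number of
# samples seen and then saturates at max_beta. A small worked example, using
# illustrative values for warmup and max_beta (not values from the configs):
#
#   warmup, max_beta = 10000, 1.0
#   for sample_amt in (0, 2500, 5000, 10000, 20000):
#       beta = min(sample_amt / max(warmup, 1.), max_beta)
#       # -> 0.0, 0.25, 0.5, 1.0, 1.0 (capped at max_beta)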
def retrieve_gen_for_gen(args, x, gen, prev_gen, prev_cls):
    """
    retrieve maximally interfered samples for generator rehearsal
    :param args: experiment configuration
    :param x: incoming batch, used to initialise the latent search
    :param gen: generator currently being trained
    :param prev_gen: generator from before the current update
    :param prev_cls: classifier from before the current update
    :return: (mem_x, mir_worked) where mem_x are the retrieved samples and
             mir_worked is 1 if the search succeeded, 0 if it fell back to
             random generation
    """
    # build the virtual (one-step-ahead) generator from the current gradients
    grad_vector = get_grad_vector(args, gen.parameters, gen.grad_dims)
    virtual_gen = get_future_step_parameters(gen, grad_vector, gen.grad_dims, args.lr)

    # posterior statistics of the incoming batch under the previous generator
    _, z_mu, z_var, _, _, _ = prev_gen(x)

    z_new_max = None
    for i in range(args.n_mem):
        with torch.no_grad():
            if args.mir_init_prior:
                z_new = prev_gen.prior.sample((z_mu.shape[0], )).to(args.device)
            else:
                z_new = prev_gen.reparameterize(z_mu, z_var)

        # latent-space search, same objective as max_z_for_gen above
        for j in range(args.mir_iters):
            z_new.requires_grad = True
            x_new = prev_gen.decode(z_new)

            prev_x_mean, prev_z_mu, prev_z_var, prev_ldj, prev_z0, prev_zk = \
                prev_gen(x_new)
            _, prev_rec, prev_kl, _ = calculate_loss(prev_x_mean, x_new, prev_z_mu,
                    prev_z_var, prev_z0, prev_zk, prev_ldj, args, beta=1)

            virtual_x_mean, virtual_z_mu, virtual_z_var, virtual_ldj, virtual_z0, virtual_zk = \
                virtual_gen(x_new)
            _, virtual_rec, virtual_kl, _ = calculate_loss(virtual_x_mean, x_new, virtual_z_mu,
                    virtual_z_var, virtual_z0, virtual_zk, virtual_ldj, args, beta=1)

            # TODO(warning, KL can explode)
            # maximise the interference
            KL = 0
            if args.gen_kl_coeff > 0.:
                KL = virtual_kl - prev_kl

            REC = 0
            if args.gen_rec_coeff > 0.:
                REC = virtual_rec - prev_rec

            # the predictions from the two models should be confident
            ENT = 0
            if args.gen_ent_coeff > 0.:
                y_pre = prev_cls(x_new)
                ENT = cross_entropy(y_pre, y_pre)
                # TODO(should we do the args.curr_entropy thing?)

            # the newly found samples should be different from each other
            DIV = 0
            if args.gen_div_coeff > 0.:
                for found_z_i in range(i):
                    DIV += F.mse_loss(
                        z_new,
                        z_new_max[found_z_i * z_new.size(0):
                                  found_z_i * z_new.size(0) + z_new.size(0)]) / i

            # (NEW) stay on the gaussian shell loss:
            SHELL = 0
            if args.gen_shell_coeff > 0.:
                SHELL = mse(
                    torch.norm(z_new, 2, dim=1),
                    torch.ones_like(torch.norm(z_new, 2, dim=1)) * np.sqrt(args.z_size))

            gain = args.gen_kl_coeff * KL + \
                   args.gen_rec_coeff * REC + \
                   -args.gen_ent_coeff * ENT + \
                   args.gen_div_coeff * DIV + \
                   -args.gen_shell_coeff * SHELL

            # gradient ascent on the gain w.r.t. the latent code
            z_g = torch.autograd.grad(gain, z_new)[0]
            z_new = (z_new + 1 * z_g).detach()

        if z_new_max is None:
            z_new_max = z_new.clone()
        else:
            z_new_max = torch.cat([z_new_max, z_new.clone()])

    z_new_max.requires_grad = False

    if np.isnan(z_new_max.to('cpu').numpy()).any():
        # the latent search diverged; fall back to random generative replay
        mir_worked = 0
        mem_x = prev_gen.generate(args.batch_size * args.n_mem).detach()
    else:
        mem_x = prev_gen.decode(z_new_max).detach()
        mir_worked = 1

    return mem_x, mir_worked
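# get_grad_vector and get_future_step_parameters implement the "virtual
# update" used above: the generator's current gradients are flattened into
# one vector, a deep copy of the generator takes a single SGD step along that
# vector, and this virtual copy is compared against prev_gen during the latent
# search. Their real implementations live elsewhere in the repo; the sketch
# below only mirrors the calling convention used above (a bound parameters
# method, grad_dims holding per-parameter element counts, the learning rate
# args.lr) and should be read as an assumption, not as the repo's code.

import copy


def get_grad_vector(args, parameters, grad_dims):
    """flatten the current gradients of parameters() into a single 1-D vector
    (zeros for parameters that have no gradient yet)"""
    grads = torch.zeros(sum(grad_dims), device=args.device)
    cnt = 0
    for param in parameters():
        if param.grad is not None:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            end = sum(grad_dims[:cnt + 1])
            grads[beg:end].copy_(param.grad.data.view(-1))
        cnt += 1
    return grads


def get_future_step_parameters(model, grad_vector, grad_dims, lr):
    """return a deep copy of model after one virtual SGD step along grad_vector"""
    virtual_model = copy.deepcopy(model)
    cnt = 0
    with torch.no_grad():
        for param in virtual_model.parameters():
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            end = sum(grad_dims[:cnt + 1])
            step = grad_vector[beg:end].to(param.device).view(param.size())
            param.data -= lr * step
            cnt += 1
    return virtual_model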