def _build(self):
    """Build inference and training computation graphs for DDPG.

    Creates actor/critic networks under the 'trainable' parameter scope,
    target networks under the 'target' scope, the critic/actor losses,
    Adam solvers, and the soft/hard target-update expressions.
    """
    # inference graph (batch size 1)
    self.infer_obs_t = nn.Variable((1,) + self.obs_shape)
    with nn.parameter_scope('trainable'):
        self.infer_policy_t = policy_network(self.infer_obs_t,
                                             self.action_size, 'actor')

    # training input variables
    self.obss_t = nn.Variable((self.batch_size,) + self.obs_shape)
    self.acts_t = nn.Variable((self.batch_size, self.action_size))
    self.rews_tp1 = nn.Variable((self.batch_size, 1))
    self.obss_tp1 = nn.Variable((self.batch_size,) + self.obs_shape)
    self.ters_tp1 = nn.Variable((self.batch_size, 1))

    # critic training: TD target y = r + gamma * Q'(s', pi'(s')) * (1 - done)
    with nn.parameter_scope('trainable'):
        q_t = q_network(self.obss_t, self.acts_t, 'critic')
    with nn.parameter_scope('target'):
        policy_tp1 = policy_network(self.obss_tp1, self.action_size, 'actor')
        q_tp1 = q_network(self.obss_tp1, policy_tp1, 'critic')
    y = self.rews_tp1 + self.gamma * q_tp1 * (1.0 - self.ters_tp1)
    # Stop gradients through the TD target so backward never touches the
    # target-network parameters, consistent with the other algorithm
    # implementations in this file (which set q_target.need_grad = False).
    y.need_grad = False
    self.critic_loss = F.mean(F.squared_error(q_t, y))

    # actor training: maximize Q(s, pi(s)) -> minimize its negation
    with nn.parameter_scope('trainable'):
        policy_t = policy_network(self.obss_t, self.action_size, 'actor')
        q_t_with_actor = q_network(self.obss_t, policy_t, 'critic')
    self.actor_loss = -F.mean(q_t_with_actor)

    # collect trainable parameters per sub-network
    with nn.parameter_scope('trainable'):
        with nn.parameter_scope('critic'):
            critic_params = nn.get_parameters()
        with nn.parameter_scope('actor'):
            actor_params = nn.get_parameters()

    # setup optimizers
    self.critic_solver = S.Adam(self.critic_lr)
    self.critic_solver.set_parameters(critic_params)
    self.actor_solver = S.Adam(self.actor_lr)
    self.actor_solver.set_parameters(actor_params)

    with nn.parameter_scope('trainable'):
        trainable_params = nn.get_parameters()
    with nn.parameter_scope('target'):
        target_params = nn.get_parameters()

    # build target update graphs:
    #   soft update: target <- (1 - tau) * target + tau * trainable
    #   sync:        target <- trainable (hard copy)
    update_targets = []
    sync_targets = []
    for key, src in trainable_params.items():
        dst = target_params[key]
        updated_dst = (1.0 - self.tau) * dst + self.tau * src
        update_targets.append(F.assign(dst, updated_dst))
        sync_targets.append(F.assign(dst, src))
    self.update_target_expr = F.sink(*update_targets)
    self.sync_target_expr = F.sink(*sync_targets)
def test_assign_forward_backward(seed, ctx, func_name):
    """Verify F.assign: forward copies src into dst (and the output), and
    backward routes gradients to the destination, not the source."""
    rng = np.random.RandomState(seed)
    dst = nn.Variable((2, 3, 4), need_grad=True)
    src = nn.Variable((2, 3, 4), need_grad=True)
    assign = F.assign(dst, src)

    src.d = rng.rand(2, 3, 4)
    assign.forward()
    # destination variable should be equal to source variable
    assert_allclose(dst.d, src.d)
    # output variable of assign function should be equal to source variable
    assert_allclose(assign.d, src.d)

    # add a constant so the graph has a node downstream of assign
    dummy = assign + rng.rand()
    dst.grad.zero()
    src.grad.zero()
    dummy.forward()
    dummy.backward()
    # gradients at destination are identical to gradients at assign operation
    assert not np.all(dst.g == np.zeros((2, 3, 4)))
    assert np.all(dst.g == assign.g)
    # source receives no gradient through assign
    assert np.all(src.g == np.zeros((2, 3, 4)))

    # check accum=False: dst.g is overwritten, not accumulated into
    assign.grad.zero()
    dst.g = rng.rand(2, 3, 4)
    f = assign.parent
    f.forward([dst, src], [assign])
    f.backward([dst, src], [assign], accum=[False])
    assert np.all(dst.g == assign.g)
    assert np.all(src.g == np.zeros((2, 3, 4)))
def build_static_graph(self):
    """Build the static training graph for GAN training.

    Creates the real-image input, latent noises, generator/discriminator
    outputs and losses, plus optional path-length (generator) and R1
    (discriminator) regularization terms, and stores all graph variables
    in ``self.parameters`` (a namedtuple).
    """
    real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                  self.img_size))
    # two latent codes (used for style mixing by the generator)
    noises = [
        F.randn(shape=(self.batch_size, self.config['latent_dim']))
        for _ in range(2)
    ]
    if self.config['regularize_gen']:
        # dlatents are needed later for path-length regularization
        fake_img, dlatents = self.generator(self.batch_size,
                                            noises,
                                            return_latent=True)
    else:
        fake_img = self.generator(self.batch_size, noises)
    # EMA generator output, used for evaluation/snapshots
    fake_img_test = self.generator_ema(self.batch_size, noises)

    # NOTE(review): the discriminator is applied to fake_img twice (once
    # inside gen_loss, once for fake_disc_out) — presumably intentional so
    # generator and discriminator losses own separate subgraphs; confirm.
    gen_loss = gen_nonsaturating_loss(self.discriminator(fake_img))

    fake_disc_out = self.discriminator(fake_img)
    real_disc_out = self.discriminator(real_img)
    disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

    var_name_list = [
        'real_img', 'noises', 'fake_img', 'gen_loss', 'disc_loss',
        'fake_disc_out', 'real_disc_out', 'fake_img_test'
    ]
    var_list = [
        real_img, noises, fake_img, gen_loss, disc_loss, fake_disc_out,
        real_disc_out, fake_img_test
    ]

    if self.config['regularize_gen']:
        # path-length regularization with a running mean of path lengths
        dlatents.need_grad = True
        mean_path_length = nn.Variable()
        pl_reg, path_mean, _ = gen_path_regularize(
            fake_img=fake_img,
            latents=dlatents,
            mean_path_length=mean_path_length)
        path_mean_update = F.assign(mean_path_length, path_mean)
        path_mean_update.name = 'path_mean_update'
        # multiplying by 0 keeps the assign node alive in the graph
        # without changing the loss value
        pl_reg += 0 * path_mean_update
        gen_loss_reg = gen_loss + pl_reg
        var_name_list.append('gen_loss_reg')
        var_list.append(gen_loss_reg)

    if self.config['regularize_disc']:
        # R1 gradient penalty on real images; need_grad is toggled on only
        # while building this term
        real_img.need_grad = True
        real_disc_out = self.discriminator(real_img)
        disc_loss_reg = disc_loss + self.config[
            'r1_coeff'] * 0.5 * disc_r1_loss(
                real_disc_out, real_img) * self.config['disc_reg_step']
        real_img.need_grad = False
        var_name_list.append('disc_loss_reg')
        var_list.append(disc_loss_reg)

    Parameters = namedtuple('Parameters', var_name_list)
    self.parameters = Parameters(*var_list)
def make_ema_updater(scope_ema, scope_cur, ema_decay):
    """Return a sink node that, when executed, blends every parameter in
    ``scope_ema`` toward its counterpart in ``scope_cur``:

        p_ema <- ema_decay * p_ema + (1 - ema_decay) * p_cur
    """
    with nn.parameter_scope(scope_cur):
        current = nn.get_parameters()
    with nn.parameter_scope(scope_ema):
        averaged = nn.get_parameters()
    updates = [
        F.assign(p_ema, ema_decay * p_ema + (1.0 - ema_decay) * current[name])
        for name, p_ema in averaged.items()
    ]
    return F.sink(*updates)
def ema_update(self):
    """Build and return a sink node that updates every 'GeneratorEMA'
    parameter as an exponential moving average of its 'Generator'
    counterpart, with decay ``self.gen_exp_weight``."""
    with nn.parameter_scope('Generator'):
        current = nn.get_parameters(grad_only=False)
    with nn.parameter_scope('GeneratorEMA'):
        averaged = nn.get_parameters(grad_only=False)
    decay = self.gen_exp_weight
    updates = [
        F.assign(p_ema, decay * p_ema + (1.0 - decay) * current[name])
        for name, p_ema in averaged.items()
    ]
    return F.sink(*updates)
def _build(self):
    """Build inference and training computation graphs for SAC with a
    learned entropy temperature.

    Creates the squashed-Gaussian actor, twin Q critics ('trainable'
    scope), target twin Q critics ('target' scope), the critic/actor/
    temperature losses, target-update expressions, and Adam solvers.
    """
    # inference graph (batch size 1)
    self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
    with nn.parameter_scope('trainable'):
        infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                    'actor')
    self.infer_act_t, _ = _squash_action(infer_dist)
    self.deterministic_act_t = infer_dist.mean()

    # training input variables
    self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
    self.acts_t = nn.Variable((self.batch_size, self.action_size))
    self.rews_tp1 = nn.Variable((self.batch_size, 1))
    self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
    self.ters_tp1 = nn.Variable((self.batch_size, 1))

    with nn.parameter_scope('trainable'):
        # log of the entropy temperature (alpha), learned jointly
        self.log_temp = get_parameter_or_create('temp', [1, 1],
                                                ConstantInitializer(0.0))
        dist_t = policy_network(self.obss_t, self.action_size, 'actor')
        dist_tp1 = policy_network(self.obss_tp1, self.action_size, 'actor')
        squashed_act_t, log_prob_t = _squash_action(dist_t)
        squashed_act_tp1, log_prob_tp1 = _squash_action(dist_tp1)
        # twin Q networks (clipped double-Q)
        q1_t = q_network(self.obss_t, self.acts_t, 'critic/1')
        q2_t = q_network(self.obss_t, self.acts_t, 'critic/2')
        q1_t_with_actor = q_network(self.obss_t, squashed_act_t, 'critic/1')
        q2_t_with_actor = q_network(self.obss_t, squashed_act_t, 'critic/2')

    with nn.parameter_scope('target'):
        q1_tp1 = q_network(self.obss_tp1, squashed_act_tp1, 'critic/1')
        q2_tp1 = q_network(self.obss_tp1, squashed_act_tp1, 'critic/2')

    # q function loss: soft Bellman target with entropy bonus,
    # masked out at terminal transitions
    q_tp1 = F.minimum2(q1_tp1, q2_tp1)
    entropy_tp1 = F.exp(self.log_temp) * log_prob_tp1
    mask = (1.0 - self.ters_tp1)
    q_target = self.rews_tp1 + self.gamma * (q_tp1 - entropy_tp1) * mask
    q_target.need_grad = False  # stop gradients through the target
    q1_loss = 0.5 * F.mean(F.squared_error(q1_t, q_target))
    q2_loss = 0.5 * F.mean(F.squared_error(q2_t, q_target))
    self.critic_loss = q1_loss + q2_loss

    # policy function loss: E[alpha * log_pi - min(Q1, Q2)]
    q_t = F.minimum2(q1_t_with_actor, q2_t_with_actor)
    entropy_t = F.exp(self.log_temp) * log_prob_t
    self.actor_loss = F.mean(entropy_t - q_t)

    # temperature loss (target entropy = -action_size)
    temp_target = log_prob_t - self.action_size
    temp_target.need_grad = False
    self.temp_loss = -F.mean(F.exp(self.log_temp) * temp_target)

    # trainable parameters per sub-network
    with nn.parameter_scope('trainable'):
        with nn.parameter_scope('critic'):
            critic_params = nn.get_parameters()
        with nn.parameter_scope('actor'):
            actor_params = nn.get_parameters()
    # target parameters (critics only; the actor has no target copy)
    with nn.parameter_scope('target/critic'):
        target_params = nn.get_parameters()

    # target update: soft (Polyak) update and hard sync expressions
    update_targets = []
    sync_targets = []
    for key, src in critic_params.items():
        dst = target_params[key]
        updated_dst = (1.0 - self.tau) * dst + self.tau * src
        update_targets.append(F.assign(dst, updated_dst))
        sync_targets.append(F.assign(dst, src))
    self.update_target_expr = F.sink(*update_targets)
    self.sync_target_expr = F.sink(*sync_targets)

    # setup solvers
    self.critic_solver = S.Adam(self.critic_lr)
    self.critic_solver.set_parameters(critic_params)
    self.actor_solver = S.Adam(self.actor_lr)
    self.actor_solver.set_parameters(actor_params)
    self.temp_solver = S.Adam(self.temp_lr)
    self.temp_solver.set_parameters({'temp': self.log_temp})
def _build(self):
    """Build inference and training computation graphs for SAC with a
    separate state-value network.

    Creates the squashed-Gaussian actor, value network, and twin Q
    critics ('trainable' scope), a target value network ('target'
    scope), the value/critic/actor losses, target-update expressions,
    and Adam solvers.
    """
    # inference graph (batch size 1)
    self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
    with nn.parameter_scope('trainable'):
        infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                    'actor')
    self.infer_act_t, _ = _squash_action(infer_dist)
    self.deterministic_act_t = infer_dist.mean()

    # training input variables
    self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
    self.acts_t = nn.Variable((self.batch_size, self.action_size))
    self.rews_tp1 = nn.Variable((self.batch_size, 1))
    self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
    self.ters_tp1 = nn.Variable((self.batch_size, 1))

    with nn.parameter_scope('trainable'):
        dist = policy_network(self.obss_t, self.action_size, 'actor')
        squashed_act_t, log_prob_t = _squash_action(dist)
        v_t = v_network(self.obss_t, 'value')
        # twin Q networks (clipped double-Q)
        q_t1 = q_network(self.obss_t, self.acts_t, 'critic/1')
        q_t2 = q_network(self.obss_t, self.acts_t, 'critic/2')
        q_t1_with_actor = q_network(self.obss_t, squashed_act_t, 'critic/1')
        q_t2_with_actor = q_network(self.obss_t, squashed_act_t, 'critic/2')
    with nn.parameter_scope('target'):
        v_tp1 = v_network(self.obss_tp1, 'value')

    # value loss: V(s) regressed toward min(Q1, Q2) - log_pi
    q_t = F.minimum2(q_t1_with_actor, q_t2_with_actor)
    v_target = q_t - log_prob_t
    v_target.need_grad = False  # stop gradients through the target
    self.value_loss = 0.5 * F.mean(F.squared_error(v_t, v_target))

    # q function loss: Bellman target via the target value network
    scaled_rews_tp1 = self.rews_tp1 * self.reward_scale
    q_target = scaled_rews_tp1 + self.gamma * v_tp1 * (1.0 - self.ters_tp1)
    q_target.need_grad = False
    q1_loss = 0.5 * F.mean(F.squared_error(q_t1, q_target))
    q2_loss = 0.5 * F.mean(F.squared_error(q_t2, q_target))
    self.critic_loss = q1_loss + q2_loss

    # policy function loss with regularization on pre-squash mean/log-std
    mean_loss = 0.5 * F.mean(dist.mean()**2)
    logstd_loss = 0.5 * F.mean(F.log(dist.stddev())**2)
    policy_reg_loss = self.policy_reg * (mean_loss + logstd_loss)
    self.objective_loss = F.mean(log_prob_t - q_t)
    self.actor_loss = self.objective_loss + policy_reg_loss

    # trainable parameters per sub-network
    with nn.parameter_scope('trainable'):
        with nn.parameter_scope('value'):
            value_params = nn.get_parameters()
        with nn.parameter_scope('critic'):
            critic_params = nn.get_parameters()
        with nn.parameter_scope('actor'):
            actor_params = nn.get_parameters()
    # target parameters (value network only)
    with nn.parameter_scope('target/value'):
        target_params = nn.get_parameters()

    # target update: soft (Polyak) update and hard sync expressions
    update_targets = []
    sync_targets = []
    for key, src in value_params.items():
        dst = target_params[key]
        updated_dst = (1.0 - self.tau) * dst + self.tau * src
        update_targets.append(F.assign(dst, updated_dst))
        sync_targets.append(F.assign(dst, src))
    self.update_target_expr = F.sink(*update_targets)
    self.sync_target_expr = F.sink(*sync_targets)

    # setup solvers
    self.value_solver = S.Adam(self.value_lr)
    self.value_solver.set_parameters(value_params)
    self.critic_solver = S.Adam(self.critic_lr)
    self.critic_solver.set_parameters(critic_params)
    self.actor_solver = S.Adam(self.actor_lr)
    self.actor_solver.set_parameters(actor_params)
def __call__(self,
             batch_size,
             style_noises,
             truncation_psi=1.0,
             return_latent=False,
             mixing_layer_index=None,
             dlatent_avg_beta=0.995):
    """Generate images with style mixing and the truncation trick.

    Args:
        batch_size (int): Number of images to generate.
        style_noises (list of nn.Variable): Two latent noise inputs; the
            list entries are normalized in place.
        truncation_psi (float): Interpolation weight toward the average
            latent (1.0 disables truncation).
        return_latent (bool): If True, also return the mixed latent.
        mixing_layer_index (int or None): Style-mixing crossover layer;
            if None, it is drawn uniformly at random per call.
        dlatent_avg_beta (float): Decay for the dlatent moving average.

    Returns:
        nn.Variable: RGB output, or (rgb_output, w_mixed) if
        ``return_latent`` is True.
    """
    with nn.parameter_scope(self.global_scope):
        # normalize each noise input by its RMS over the latent axis
        for i in range(len(style_noises)):
            style_noises[i] = F.div2(
                style_noises[i],
                F.pow_scalar(F.add_scalar(F.mean(style_noises[i]**2.,
                                                 axis=1,
                                                 keepdims=True),
                                          1e-8,
                                          inplace=False),
                             0.5,
                             inplace=False))

        # map each noise to a latent code w via the mapping network
        w = [
            mapping_network(style_noises[0],
                            outmaps=self.mapping_network_dim,
                            num_layers=self.mapping_network_num_layers)
        ]
        w += [
            mapping_network(style_noises[1],
                            outmaps=self.mapping_network_dim,
                            num_layers=self.mapping_network_num_layers)
        ]

        dlatent_avg = nn.parameter.get_parameter_or_create(
            name="dlatent_avg", shape=(1, 512))

        # Moving average update of dlatent_avg
        batch_avg = F.mean((w[0] + w[1]) * 0.5, axis=0, keepdims=True)
        update_op = F.assign(
            dlatent_avg, lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
        update_op.name = 'dlatent_avg_update'
        # adding 0 * update_op keeps the assign node alive in the graph
        # without changing the value of dlatent_avg used below
        dlatent_avg = F.identity(dlatent_avg) + 0 * update_op

        # truncation trick: pull each w toward the average latent
        w = [lerp(dlatent_avg, _, truncation_psi) for _ in w]

        # learned constant input of the synthesis network, broadcast to batch
        constant_bc = nn.parameter.get_parameter_or_create(
            name="G_synthesis/4x4/Const/const",
            shape=(1, 512, 4, 4),
            initializer=np.random.randn(1, 512, 4, 4).astype(np.float32))
        constant_bc = F.broadcast(constant_bc,
                                  (batch_size, ) + constant_bc.shape[1:])

        # choose the style-mixing crossover layer (random if unspecified)
        if mixing_layer_index is None:
            mixing_layer_index_var = F.randint(1,
                                               len(self.resolutions) * 2,
                                               (1, ))
        else:
            mixing_layer_index_var = F.constant(val=mixing_layer_index,
                                                shape=(1, ))
        # per-layer switch: 0 up to and including the crossover layer,
        # 1 after it; w_mixed takes w[1] where the switch is 0 and w[0]
        # where it is 1
        mixing_switch_var = F.clip_by_value(
            F.arange(0, len(self.resolutions) * 2) - mixing_layer_index_var,
            0, 1)
        mixing_switch_var_re = F.reshape(mixing_switch_var,
                                         (1, mixing_switch_var.shape[0], 1),
                                         inplace=False)
        w0 = F.reshape(w[0], (batch_size, 1, w[0].shape[1]), inplace=False)
        w1 = F.reshape(w[1], (batch_size, 1, w[0].shape[1]), inplace=False)
        w_mixed = w0 * mixing_switch_var_re + \
            w1 * (1 - mixing_switch_var_re)

        rgb_output = self.synthesis(w_mixed, constant_bc)

        if return_latent:
            return rgb_output, w_mixed
        else:
            return rgb_output
def sample_loop(self,
                model,
                shape,
                sampler,
                noise=None,
                dump_interval=-1,
                progress=False,
                without_auto_forward=False):
    """
    Iteratively sample data from model from t=T to t=0.
    T is specified as the length of betas given to __init__().

    Args:
        model (callable): A callable that takes x_t and t and predicts noise (and sigma related parameters).
        shape (list like object): A data shape.
        sampler (callable): A function to sample x_{t-1} given x_{t} and t. Typically, self.p_sample or self.ddim_sample.
        noise (np.ndarray): Initial noise for x_T. If None, F.randn(shape) (or np.random.randn) will be used.
        dump_interval (int): If > 0, all intermediate results at every `dump_interval` step will be returned as a list.
                             e.g. if dump_interval = 10, the predicted results at {10, 20, 30, ...} will be returned.
        progress (bool): If True, tqdm will be used to show the sampling progress.
        without_auto_forward (bool): If True, build a static graph once and reuse it via F.assign updates
                                     instead of rebuilding under auto-forward mode each step.

    Returns:
        - x_0 (np.ndarray): the final sampled result of x_0
        - samples (a list of (step, np.ndarray)): the sampled results at every `dump_interval`
        - pred_x_starts (a list of (step, np.ndarray)): the predicted x_0 from each x_t at every `dump_interval`
    """
    T = self.num_timesteps
    indices = list(range(T))[::-1]  # iterate t from T-1 down to 0

    samples = []
    pred_x_starts = []

    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)

    if without_auto_forward:
        if noise is None:
            noise = np.random.randn(*shape)
        else:
            assert isinstance(noise, np.ndarray)
            assert noise.shape == shape
        x_t = nn.Variable.from_numpy_array(noise)
        t = nn.Variable.from_numpy_array([T - 1 for _ in range(shape[0])])

        # build a static graph once; each iteration re-executes it, then
        # writes y back into x_t and decrements t via F.assign
        y, pred_x_start = sampler(model, x_t, t)
        up_x_t = F.assign(x_t, y)
        up_t = F.assign(t, t - 1)
        update = F.sink(up_x_t, up_t)

        # NOTE(review): unlike the auto-forward branch below, this branch
        # never passes no_noise=True at the final step — confirm intended.
        cnt = 0
        for step in indices:
            y.forward(clear_buffer=True)
            update.forward(clear_buffer=True)
            cnt += 1
            if dump_interval > 0 and cnt % dump_interval == 0:
                samples.append((step, y.d.copy()))
                pred_x_starts.append((step, pred_x_start.d.copy()))
    else:
        with nn.auto_forward():
            if noise is None:
                x_t = F.randn(shape=shape)
            else:
                assert isinstance(noise, np.ndarray)
                assert noise.shape == shape
                x_t = nn.Variable.from_numpy_array(noise)
            cnt = 0
            for step in indices:
                t = F.constant(step, shape=(shape[0], ))
                # no_noise disables the added noise at the last step (t=0)
                x_t, pred_x_start = sampler(model,
                                            x_t,
                                            t,
                                            no_noise=step == 0)

                cnt += 1
                if dump_interval > 0 and cnt % dump_interval == 0:
                    samples.append((step, x_t.d.copy()))
                    pred_x_starts.append((step, pred_x_start.d.copy()))

    assert x_t.shape == shape
    return x_t.d.copy(), samples, pred_x_starts
def build_static_graph(self):
    """Build the static training graph for few-shot GAN training.

    In the 'cdc' few-shot mode, adds a cross-domain correspondence loss
    against a source generator and a patch-discriminator term gated by
    ``self.PD_switch_var``. All graph variables are stored in
    ``self.parameters`` (a namedtuple).
    """
    real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                  self.img_size))
    noises = [
        F.randn(shape=(self.batch_size, self.config['latent_dim']))
        for _ in range(2)
    ]
    if self.few_shot_config['common']['type'] == 'cdc':
        # replace the random noises with NoiseTop-provided latents and
        # grab the patch-discriminator on/off switch variable
        NT_class = NoiseTop(n_train=self.train_loader.size,
                            latent_dim=self.config['latent_dim'],
                            batch_size=self.batch_size)
        noises = NT_class()
        self.PD_switch_var = NT_class.PD_switch_var

    if self.config['regularize_gen']:
        # dlatents are needed later for path-length regularization
        fake_img, dlatents = self.generator(self.batch_size,
                                            noises,
                                            return_latent=True)
    else:
        fake_img = self.generator(self.batch_size, noises)
    # EMA generator output, used for evaluation/snapshots
    fake_img_test = self.generator_ema(self.batch_size, noises)

    if self.few_shot_config['common']['type'] != 'cdc':
        fake_disc_out = self.discriminator(fake_img)
        real_disc_out = self.discriminator(real_img)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)
    gen_loss = 0
    if self.few_shot_config['common']['type'] == 'cdc':
        # cross-domain correspondence loss against the source generator
        fake_img_s = self.generator_s(self.batch_size, noises)
        cdc_loss = CrossDomainCorrespondence(
            fake_img,
            fake_img_s,
            _choice_num=self.few_shot_config['cdc']['feature_num'],
            _layer_fix_switch=self.few_shot_config['cdc']['layer_fix'])
        gen_loss += self.few_shot_config['cdc']['lambda'] * cdc_loss
        # --- PatchDiscriminator ---
        fake_disc_out, fake_feature_var = self.discriminator(
            fake_img, patch_switch=True, index=0)
        real_disc_out, real_feature_var = self.discriminator(
            real_img, patch_switch=True, index=0)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)
        disc_loss_patch = disc_logistic_loss(fake_feature_var,
                                             real_feature_var)
        # patch-level term is enabled/disabled by PD_switch_var
        disc_loss += self.PD_switch_var * disc_loss_patch
    gen_loss += gen_nonsaturating_loss(fake_disc_out)

    var_name_list = [
        'real_img', 'noises', 'fake_img', 'gen_loss', 'disc_loss',
        'fake_disc_out', 'real_disc_out', 'fake_img_test'
    ]
    var_list = [
        real_img, noises, fake_img, gen_loss, disc_loss, fake_disc_out,
        real_disc_out, fake_img_test
    ]

    if self.config['regularize_gen']:
        # path-length regularization with a running mean of path lengths
        dlatents.need_grad = True
        mean_path_length = nn.Variable()
        pl_reg, path_mean, _ = gen_path_regularize(
            fake_img=fake_img,
            latents=dlatents,
            mean_path_length=mean_path_length)
        path_mean_update = F.assign(mean_path_length, path_mean)
        path_mean_update.name = 'path_mean_update'
        # multiplying by 0 keeps the assign node alive in the graph
        # without changing the loss value
        pl_reg += 0 * path_mean_update
        gen_loss_reg = gen_loss + pl_reg
        var_name_list.append('gen_loss_reg')
        var_list.append(gen_loss_reg)

    if self.config['regularize_disc']:
        # R1 gradient penalty on real images; need_grad is toggled on only
        # while building this term
        real_img.need_grad = True
        real_disc_out = self.discriminator(real_img)
        disc_loss_reg = disc_loss + self.config[
            'r1_coeff'] * 0.5 * disc_r1_loss(
                real_disc_out, real_img) * self.config['disc_reg_step']
        real_img.need_grad = False
        var_name_list.append('disc_loss_reg')
        var_list.append(disc_loss_reg)

    Parameters = namedtuple('Parameters', var_name_list)
    self.parameters = Parameters(*var_list)
def ema_update(p_ema, p_train):
    """Return an assign op that blends p_train into p_ema:
    p_ema <- ema_decay * p_ema + (1 - ema_decay) * p_train."""
    blended = ema_decay * p_ema + (1. - ema_decay) * p_train
    return F.assign(p_ema, blended)