Example #1
0
    def _build(self):
        # Build the graphs for a DDPG-style agent: a single-observation
        # inference policy, critic and actor losses, Adam solvers for the
        # trainable critic/actor parameters, and expressions that soft-update
        # (update_target_expr) or hard-copy (sync_target_expr) every
        # 'trainable' parameter into its 'target' counterpart.

        # inference
        self.infer_obs_t = nn.Variable((1,) + self.obs_shape)

        with nn.parameter_scope('trainable'):
            self.infer_policy_t = policy_network(self.infer_obs_t,
                                                 self.action_size, 'actor')

        # training inputs (filled by the caller before forward/backward)
        self.obss_t = nn.Variable((self.batch_size,) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size,) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        # critic training
        with nn.parameter_scope('trainable'):
            q_t = q_network(self.obss_t, self.acts_t, 'critic')
        with nn.parameter_scope('target'):
            policy_tp1 = policy_network(self.obss_tp1, self.action_size,
                                        'actor')
            q_tp1 = q_network(self.obss_tp1, policy_tp1, 'critic')
        # TD target; (1 - ters_tp1) masks out the bootstrap term.
        y = self.rews_tp1 + self.gamma * q_tp1 * (1.0 - self.ters_tp1)
        # The TD target is a fixed regression target. Stop gradients here so
        # critic_loss.backward() does not propagate into the target
        # actor/critic parameters. No solver updates those parameters or zeroes
        # their gradients, so any gradient reaching them is wasted work that
        # keeps accumulating.
        y.need_grad = False
        self.critic_loss = F.mean(F.squared_error(q_t, y))

        # actor training: maximize Q of the actor's own action
        with nn.parameter_scope('trainable'):
            policy_t = policy_network(self.obss_t, self.action_size, 'actor')
            q_t_with_actor = q_network(self.obss_t, policy_t, 'critic')
        self.actor_loss = -F.mean(q_t_with_actor)

        # get neural network parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()

        # setup optimizers
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)

        with nn.parameter_scope('trainable'):
            trainable_params = nn.get_parameters()
        with nn.parameter_scope('target'):
            target_params = nn.get_parameters()

        # build target update:
        #   update: dst <- (1 - tau) * dst + tau * src
        #   sync:   dst <- src
        # assumes 'target' holds a parameter under every key found in
        # 'trainable' (KeyError otherwise)
        update_targets = []
        sync_targets = []
        for key, src in trainable_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)
Example #2
0
def test_assign_forward_backward(seed, ctx, func_name):
    # Check F.assign: forward copies src into dst and into the output;
    # backward routes the output gradient to dst only, and src receives none.
    # seed/ctx/func_name are presumably supplied by pytest parametrization;
    # only `seed` is used in this body.
    rng = np.random.RandomState(seed)
    dst = nn.Variable((2, 3, 4), need_grad=True)
    src = nn.Variable((2, 3, 4), need_grad=True)

    assign = F.assign(dst, src)

    src.d = rng.rand(2, 3, 4)
    assign.forward()

    # destination variable should be equal to source variable
    assert_allclose(dst.d, src.d)
    # output variable of assign function should be equal to source variable
    assert_allclose(assign.d, src.d)

    # downstream op (assign + random scalar) used as the backward root
    dummy = assign + rng.rand()

    dst.grad.zero()
    src.grad.zero()
    dummy.forward()
    dummy.backward()

    # gradients at destination are identical to gradients at assign operation
    assert not np.all(dst.g == np.zeros((2, 3, 4)))
    assert np.all(dst.g == assign.g)
    # no gradient flows back into the source
    assert np.all(src.g == np.zeros((2, 3, 4)))

    # check accum=False
    # dst.g is pre-filled with random values and assign.g is zero; with
    # accum=False the backward must overwrite dst.g rather than add to it.
    assign.grad.zero()
    dst.g = rng.rand(2, 3, 4)
    f = assign.parent
    f.forward([dst, src], [assign])
    f.backward([dst, src], [assign], accum=[False])
    assert np.all(dst.g == assign.g)
    assert np.all(src.g == np.zeros((2, 3, 4)))
Example #3
0
    def build_static_graph(self):
        """Build the static GAN training graph and store it on ``self``.

        The created variables are exposed as ``self.parameters``, a namedtuple
        with fields real_img, noises, fake_img, gen_loss, disc_loss,
        fake_disc_out, real_disc_out and fake_img_test. ``gen_loss_reg`` is
        appended when ``config['regularize_gen']`` is set, and
        ``disc_loss_reg`` when ``config['regularize_disc']`` is set.
        """
        cfg = self.config
        batch = self.batch_size
        real_img = nn.Variable(shape=(batch, 3, self.img_size, self.img_size))
        noises = [F.randn(shape=(batch, cfg['latent_dim'])) for _ in range(2)]

        if cfg['regularize_gen']:
            fake_img, dlatents = self.generator(batch, noises,
                                                return_latent=True)
        else:
            fake_img = self.generator(batch, noises)
        fake_img_test = self.generator_ema(batch, noises)

        gen_loss = gen_nonsaturating_loss(self.discriminator(fake_img))

        fake_disc_out = self.discriminator(fake_img)
        real_disc_out = self.discriminator(real_img)
        disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

        # (field name, variable) pairs, in namedtuple field order
        fields = [
            ('real_img', real_img),
            ('noises', noises),
            ('fake_img', fake_img),
            ('gen_loss', gen_loss),
            ('disc_loss', disc_loss),
            ('fake_disc_out', fake_disc_out),
            ('real_disc_out', real_disc_out),
            ('fake_img_test', fake_img_test),
        ]

        if cfg['regularize_gen']:
            dlatents.need_grad = True
            mean_path_length = nn.Variable()
            pl_reg, path_mean, _ = gen_path_regularize(
                fake_img=fake_img,
                latents=dlatents,
                mean_path_length=mean_path_length)
            # The assign writes path_mean into mean_path_length. Adding it with
            # weight 0 makes it an ancestor of the loss, so forwarding the
            # loss also runs the update without changing the loss value.
            path_mean_update = F.assign(mean_path_length, path_mean)
            path_mean_update.name = 'path_mean_update'
            pl_reg += 0 * path_mean_update
            fields.append(('gen_loss_reg', gen_loss + pl_reg))

        if cfg['regularize_disc']:
            # gradients w.r.t. real_img are only enabled while the R1 term
            # is being built
            real_img.need_grad = True
            r1_disc_out = self.discriminator(real_img)
            r1_penalty = disc_r1_loss(r1_disc_out, real_img)
            disc_loss_reg = disc_loss + cfg['r1_coeff'] * 0.5 * r1_penalty \
                * cfg['disc_reg_step']
            real_img.need_grad = False
            fields.append(('disc_loss_reg', disc_loss_reg))

        names, values = zip(*fields)
        Parameters = namedtuple('Parameters', names)
        self.parameters = Parameters(*values)
Example #4
0
def make_ema_updater(scope_ema, scope_cur, ema_decay, grad_only=True):
    """Build a graph node that updates EMA parameters in place.

    Forwarding the returned node assigns, for every parameter found under
    ``scope_ema``::

        p_ema <- ema_decay * p_ema + (1 - ema_decay) * p_cur

    where ``p_cur`` is the parameter with the same name under ``scope_cur``.

    Args:
        scope_ema: parameter scope holding the averaged (EMA) parameters.
        scope_cur: parameter scope holding the parameters being trained.
        ema_decay: weight given to the previous EMA value.
        grad_only: forwarded to ``nn.get_parameters`` for both scopes. The
            default ``True`` keeps the previous behavior (only parameters
            with ``need_grad=True`` are collected). Pass ``False`` when the
            EMA copy was created with ``need_grad=False``; otherwise it is
            silently skipped.

    Returns:
        An ``F.sink`` variable; call ``.forward()`` on it to apply the update.

    Raises:
        KeyError: if a parameter under ``scope_ema`` has no counterpart of
            the same name under ``scope_cur``.
    """
    with nn.parameter_scope(scope_cur):
        params_cur = nn.get_parameters(grad_only=grad_only)
    with nn.parameter_scope(scope_ema):
        params_ema = nn.get_parameters(grad_only=grad_only)
    update_ema_list = []
    for name, param_ema in params_ema.items():
        params_ema_updated = ema_decay * param_ema + \
            (1.0 - ema_decay) * params_cur[name]
        update_ema_list.append(F.assign(param_ema, params_ema_updated))
    return F.sink(*update_ema_list)
Example #5
0
 def ema_update(self):
     """Build a node that moves 'GeneratorEMA' parameters towards 'Generator'.

     Forwarding the returned node sets every EMA parameter to
     ``gen_exp_weight * ema + (1 - gen_exp_weight) * current``, pairing
     parameters by name. Both scopes are read with ``grad_only=False``, so
     parameters with ``need_grad=False`` are included.

     Returns:
         An ``F.sink`` variable that applies all assignments when forwarded.
     """
     with nn.parameter_scope('Generator'):
         current = nn.get_parameters(grad_only=False)
     with nn.parameter_scope('GeneratorEMA'):
         averaged = nn.get_parameters(grad_only=False)
     decay = self.gen_exp_weight
     assignments = [
         F.assign(ema_param, decay * ema_param + (1.0 - decay) * current[key])
         for key, ema_param in averaged.items()
     ]
     return F.sink(*assignments)
Example #6
0
    def _build(self):
        # Build all graphs for a soft-actor-critic style agent with twin
        # critics ('critic/1', 'critic/2') and a learned temperature
        # (exp(log_temp)). Results are stored on self: inference outputs,
        # critic/actor/temperature losses, target-critic update expressions
        # and one Adam solver per parameter group.

        # inference graph
        self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
        with nn.parameter_scope('trainable'):
            infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                        'actor')
        # second output of _squash_action (a log-probability elsewhere in
        # this method) is not needed for inference
        self.infer_act_t, _ = _squash_action(infer_dist)
        # NOTE(review): this is the raw distribution mean and does not go
        # through _squash_action like infer_act_t does — confirm intended.
        self.deterministic_act_t = infer_dist.mean()

        # training graph
        self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        with nn.parameter_scope('trainable'):
            # log of the temperature, initialized to 0 (temperature 1)
            self.log_temp = get_parameter_or_create('temp', [1, 1],
                                                    ConstantInitializer(0.0))
            # the current (trainable) policy is also used for the next state
            dist_t = policy_network(self.obss_t, self.action_size, 'actor')
            dist_tp1 = policy_network(self.obss_tp1, self.action_size, 'actor')
            squashed_act_t, log_prob_t = _squash_action(dist_t)
            squashed_act_tp1, log_prob_tp1 = _squash_action(dist_tp1)
            # Q of the stored actions (critic loss) ...
            q1_t = q_network(self.obss_t, self.acts_t, 'critic/1')
            q2_t = q_network(self.obss_t, self.acts_t, 'critic/2')
            # ... and of the policy's own actions (actor loss)
            q1_t_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/1')
            q2_t_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/2')

        # target critics evaluated on next-state policy actions
        with nn.parameter_scope('target'):
            q1_tp1 = q_network(self.obss_tp1, squashed_act_tp1, 'critic/1')
            q2_tp1 = q_network(self.obss_tp1, squashed_act_tp1, 'critic/2')

        # q function loss
        # target = r + gamma * (min(Q1', Q2') - temp * log_prob') * (1 - ter),
        # held constant (need_grad=False) during backward
        q_tp1 = F.minimum2(q1_tp1, q2_tp1)
        entropy_tp1 = F.exp(self.log_temp) * log_prob_tp1
        mask = (1.0 - self.ters_tp1)
        q_target = self.rews_tp1 + self.gamma * (q_tp1 - entropy_tp1) * mask
        q_target.need_grad = False
        q1_loss = 0.5 * F.mean(F.squared_error(q1_t, q_target))
        q2_loss = 0.5 * F.mean(F.squared_error(q2_t, q_target))
        self.critic_loss = q1_loss + q2_loss

        # policy function loss: mean(temp * log_prob - min(Q1, Q2))
        q_t = F.minimum2(q1_t_with_actor, q2_t_with_actor)
        entropy_t = F.exp(self.log_temp) * log_prob_t
        self.actor_loss = F.mean(entropy_t - q_t)

        # temperature loss: -mean(temp * (log_prob + target_entropy)) with
        # target_entropy = -action_size; the log-prob term is held constant
        temp_target = log_prob_t - self.action_size
        temp_target.need_grad = False
        self.temp_loss = -F.mean(F.exp(self.log_temp) * temp_target)

        # trainable parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()
        # target parameters (only the critics have target copies)
        with nn.parameter_scope('target/critic'):
            target_params = nn.get_parameters()

        # target update:
        #   update: dst <- (1 - tau) * dst + tau * src
        #   sync:   dst <- src
        # assumes target/critic has a parameter for every key in
        # trainable/critic (KeyError otherwise)
        update_targets = []
        sync_targets = []
        for key, src in critic_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)

        # setup solvers (the temperature has its own solver over log_temp)
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)
        self.temp_solver = S.Adam(self.temp_lr)
        self.temp_solver.set_parameters({'temp': self.log_temp})
Example #7
0
    def _build(self):
        # Build all graphs for a soft-actor-critic style agent that also
        # learns a state-value network ('value') with a target copy, plus
        # twin critics. Results are stored on self: inference outputs,
        # value/critic/actor losses, target-value update expressions and one
        # Adam solver per parameter group.

        # inference graph
        self.infer_obs_t = nn.Variable((1, ) + self.obs_shape)
        with nn.parameter_scope('trainable'):
            infer_dist = policy_network(self.infer_obs_t, self.action_size,
                                        'actor')
        # second output of _squash_action (a log-probability elsewhere in
        # this method) is not needed for inference
        self.infer_act_t, _ = _squash_action(infer_dist)
        # NOTE(review): this is the raw distribution mean and does not go
        # through _squash_action like infer_act_t does — confirm intended.
        self.deterministic_act_t = infer_dist.mean()

        # training graph
        self.obss_t = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, ) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        with nn.parameter_scope('trainable'):
            dist = policy_network(self.obss_t, self.action_size, 'actor')
            squashed_act_t, log_prob_t = _squash_action(dist)
            v_t = v_network(self.obss_t, 'value')
            # Q of the stored actions (critic loss) ...
            q_t1 = q_network(self.obss_t, self.acts_t, 'critic/1')
            q_t2 = q_network(self.obss_t, self.acts_t, 'critic/2')
            # ... and of the policy's own actions (value and actor losses)
            q_t1_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/1')
            q_t2_with_actor = q_network(self.obss_t, squashed_act_t,
                                        'critic/2')

        # target value network evaluated on the next state
        with nn.parameter_scope('target'):
            v_tp1 = v_network(self.obss_tp1, 'value')

        # value loss: V(s) regresses onto min(Q1, Q2) - log_prob, held
        # constant (need_grad=False) during backward
        q_t = F.minimum2(q_t1_with_actor, q_t2_with_actor)
        v_target = q_t - log_prob_t
        v_target.need_grad = False
        self.value_loss = 0.5 * F.mean(F.squared_error(v_t, v_target))

        # q function loss: target = reward * reward_scale
        #   + gamma * V_target(s') * (1 - ter), held constant
        scaled_rews_tp1 = self.rews_tp1 * self.reward_scale
        q_target = scaled_rews_tp1 + self.gamma * v_tp1 * (1.0 - self.ters_tp1)
        q_target.need_grad = False
        q1_loss = 0.5 * F.mean(F.squared_error(q_t1, q_target))
        q2_loss = 0.5 * F.mean(F.squared_error(q_t2, q_target))
        self.critic_loss = q1_loss + q2_loss

        # policy function loss: mean(log_prob - min(Q1, Q2)) plus a
        # regularizer on the squared distribution mean and squared log stddev,
        # scaled by policy_reg
        mean_loss = 0.5 * F.mean(dist.mean()**2)
        logstd_loss = 0.5 * F.mean(F.log(dist.stddev())**2)
        policy_reg_loss = self.policy_reg * (mean_loss + logstd_loss)
        self.objective_loss = F.mean(log_prob_t - q_t)
        self.actor_loss = self.objective_loss + policy_reg_loss

        # trainable parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('value'):
                value_params = nn.get_parameters()
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()
        # target parameters (only the value network has a target copy)
        with nn.parameter_scope('target/value'):
            target_params = nn.get_parameters()

        # target update:
        #   update: dst <- (1 - tau) * dst + tau * src
        #   sync:   dst <- src
        # assumes target/value has a parameter for every key in
        # trainable/value (KeyError otherwise)
        update_targets = []
        sync_targets = []
        for key, src in value_params.items():
            dst = target_params[key]
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)

        # setup solvers
        self.value_solver = S.Adam(self.value_lr)
        self.value_solver.set_parameters(value_params)
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)
Example #8
0
    def __call__(self,
                 batch_size,
                 style_noises,
                 truncation_psi=1.0,
                 return_latent=False,
                 mixing_layer_index=None,
                 dlatent_avg_beta=0.995):
        """Build the generator graph for one batch of images.

        Args:
            batch_size (int): number of images to generate.
            style_noises (list of nn.Variable): latent inputs. Every entry is
                normalized (divided by sqrt(mean(z**2, axis=1) + 1e-8)); the
                first two entries are mapped to style vectors and mixed. The
                list passed in is not modified, so it can safely be reused
                for another generator call.
            truncation_psi (float): factor passed to ``lerp`` between the
                running-average latent and each mapped style vector.
            return_latent (bool): if True, also return the mixed latents.
            mixing_layer_index (int or None): per-layer switch point between
                the two mapped latents (see mixing_switch_var below). Drawn
                with ``F.randint`` when None.
            dlatent_avg_beta (float): factor passed to ``lerp`` for the
                running-average update of ``dlatent_avg``.

        Returns:
            ``rgb_output``, or ``(rgb_output, w_mixed)`` when
            ``return_latent`` is True, where ``w_mixed`` has shape
            ``(batch_size, len(self.resolutions) * 2, latent_dim)``.
        """
        with nn.parameter_scope(self.global_scope):
            # normalize noise inputs. Build a new list instead of writing back
            # into style_noises, so the caller's variables are left untouched.
            normalized_noises = [
                F.div2(
                    noise,
                    F.pow_scalar(F.add_scalar(F.mean(noise**2.,
                                                     axis=1,
                                                     keepdims=True),
                                              1e-8,
                                              inplace=False),
                                 0.5,
                                 inplace=False))
                for noise in style_noises
            ]

            # get latent code
            w = [
                mapping_network(normalized_noises[0],
                                outmaps=self.mapping_network_dim,
                                num_layers=self.mapping_network_num_layers)
            ]
            w += [
                mapping_network(normalized_noises[1],
                                outmaps=self.mapping_network_dim,
                                num_layers=self.mapping_network_num_layers)
            ]

            # Running-average latent, sized from the mapping output so it
            # always matches batch_avg below (previously hard-coded to 512).
            dlatent_avg = nn.parameter.get_parameter_or_create(
                name="dlatent_avg", shape=(1, w[0].shape[1]))

            # Moving average update of dlatent_avg. Adding 0 * update_op makes
            # the assign an ancestor of every use of dlatent_avg, so a forward
            # pass also performs the update.
            batch_avg = F.mean((w[0] + w[1]) * 0.5, axis=0, keepdims=True)
            update_op = F.assign(
                dlatent_avg, lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            update_op.name = 'dlatent_avg_update'
            dlatent_avg = F.identity(dlatent_avg) + 0 * update_op

            # truncation trick
            w = [lerp(dlatent_avg, _, truncation_psi) for _ in w]

            # generate output from generator
            constant_bc = nn.parameter.get_parameter_or_create(
                name="G_synthesis/4x4/Const/const",
                shape=(1, 512, 4, 4),
                initializer=np.random.randn(1, 512, 4, 4).astype(np.float32))
            constant_bc = F.broadcast(constant_bc,
                                      (batch_size, ) + constant_bc.shape[1:])

            if mixing_layer_index is None:
                mixing_layer_index_var = F.randint(1,
                                                   len(self.resolutions) * 2,
                                                   (1, ))
            else:
                mixing_layer_index_var = F.constant(val=mixing_layer_index,
                                                    shape=(1, ))
            # Per-layer switch: clip(layer_index - mixing_index, 0, 1) is 0
            # for layers up to the mixing index (those take w[1]) and 1
            # afterwards (those take w[0]).
            mixing_switch_var = F.clip_by_value(
                F.arange(0,
                         len(self.resolutions) * 2) - mixing_layer_index_var,
                0, 1)
            mixing_switch_var_re = F.reshape(
                mixing_switch_var, (1, mixing_switch_var.shape[0], 1),
                inplace=False)
            w0 = F.reshape(w[0], (batch_size, 1, w[0].shape[1]), inplace=False)
            w1 = F.reshape(w[1], (batch_size, 1, w[0].shape[1]), inplace=False)
            w_mixed = w0 * mixing_switch_var_re + \
                w1 * (1 - mixing_switch_var_re)

            rgb_output = self.synthesis(w_mixed, constant_bc)

            if return_latent:
                return rgb_output, w_mixed
            else:
                return rgb_output
Example #9
0
    def sample_loop(self, model, shape, sampler,
                    noise=None,
                    dump_interval=-1,
                    progress=False,
                    without_auto_forward=False):
        """
        Iteratively sample data from the model, from t=T-1 down to t=0.
        T is ``self.num_timesteps`` (the length of betas given to __init__()).

        Args:
            model (callable):
                A callable that takes x_t and t and predicts noise (and sigma related parameters).
            shape (list like object): Data shape. It is converted to a tuple internally.
            sampler (callable): A function to sample x_{t-1} given x_{t} and t, returning
                ``(x_{t-1}, pred_x_start)``. Typically, self.p_sample or self.ddim_sample.
            noise (np.ndarray, optional): Initial x_T with shape ``shape``. If None,
                standard normal noise is drawn (np.random.randn when
                ``without_auto_forward`` is True, F.randn otherwise).
            dump_interval (int):
                If > 0, intermediate results are recorded after every
                `dump_interval`-th sampling step (the 10th, 20th, ... step
                when dump_interval = 10).
            progress (bool): If True, tqdm will be used to show the sampling progress.
            without_auto_forward (bool): If True, build a single-step graph once and
                re-run it, writing the new x_t and t back with F.assign. Otherwise the
                graph is built step by step under nn.auto_forward().

        Returns:
            - x_0 (np.ndarray): a copy of the final sampled x_0
            - samples (list of (int, np.ndarray)): (step, x) pairs recorded every `dump_interval` steps
            - pred_x_starts (list of (int, np.ndarray)): (step, predicted x_0) pairs recorded
              every `dump_interval` steps
        """
        T = self.num_timesteps
        indices = list(range(T))[::-1]
        # `shape` may be any list-like; the shape comparisons below need a tuple.
        shape = tuple(shape)

        samples = []
        pred_x_starts = []

        if progress:
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        if without_auto_forward:
            if noise is None:
                noise = np.random.randn(*shape)
            else:
                assert isinstance(noise, np.ndarray)
                assert noise.shape == shape

            x_t = nn.Variable.from_numpy_array(noise)
            t = nn.Variable.from_numpy_array([T - 1 for _ in range(shape[0])])

            # build graph for a single step; forwarding `update` evaluates the
            # sampler and then writes the results back into x_t and t.
            # NOTE(review): unlike the auto-forward branch, no_noise is not
            # passed here, so the final step (t=0) uses the same graph.
            y, pred_x_start = sampler(model, x_t, t)
            if dump_interval > 0:
                # keep pred_x_start readable after forward(clear_buffer=True)
                pred_x_start.persistent = True
            up_x_t = F.assign(x_t, y)
            up_t = F.assign(t, t - 1)
            update = F.sink(up_x_t, up_t)

            for cnt, step in enumerate(indices, start=1):
                # One forward of `update` computes y and applies the assigns.
                # A separate y.forward() would evaluate the model twice per step.
                update.forward(clear_buffer=True)

                if dump_interval > 0 and cnt % dump_interval == 0:
                    # x_t (a leaf) now holds the value of y
                    samples.append((step, x_t.d.copy()))
                    pred_x_starts.append((step, pred_x_start.d.copy()))
        else:
            with nn.auto_forward():
                if noise is None:
                    x_t = F.randn(shape=shape)
                else:
                    assert isinstance(noise, np.ndarray)
                    assert noise.shape == shape
                    x_t = nn.Variable.from_numpy_array(noise)
                for cnt, step in enumerate(indices, start=1):
                    t = F.constant(step, shape=(shape[0], ))
                    x_t, pred_x_start = sampler(
                        model, x_t, t, no_noise=step == 0)
                    if dump_interval > 0 and cnt % dump_interval == 0:
                        samples.append((step, x_t.d.copy()))
                        pred_x_starts.append((step, pred_x_start.d.copy()))

        assert x_t.shape == shape
        return x_t.d.copy(), samples, pred_x_starts
Example #10
0
    def build_static_graph(self):
        # Build the static GAN training graph and expose its variables through
        # self.parameters (a namedtuple). When few_shot_config['common']['type']
        # is 'cdc': the latent inputs come from NoiseTop; a
        # CrossDomainCorrespondence loss between self.generator and
        # self.generator_s outputs is added to the generator loss; and the
        # discriminator is queried with patch_switch=True to add a
        # patch-level discriminator loss weighted by self.PD_switch_var.
        real_img = nn.Variable(shape=(self.batch_size, 3, self.img_size,
                                      self.img_size))
        noises = [
            F.randn(shape=(self.batch_size, self.config['latent_dim']))
            for _ in range(2)
        ]
        if self.few_shot_config['common']['type'] == 'cdc':
            # NoiseTop replaces the plain randn noises; its PD_switch_var
            # weights the patch-level loss below
            NT_class = NoiseTop(n_train=self.train_loader.size,
                                latent_dim=self.config['latent_dim'],
                                batch_size=self.batch_size)
            noises = NT_class()
            self.PD_switch_var = NT_class.PD_switch_var

        if self.config['regularize_gen']:
            fake_img, dlatents = self.generator(self.batch_size,
                                                noises,
                                                return_latent=True)
        else:
            fake_img = self.generator(self.batch_size, noises)
        fake_img_test = self.generator_ema(self.batch_size, noises)

        if self.few_shot_config['common']['type'] != 'cdc':
            fake_disc_out = self.discriminator(fake_img)
            real_disc_out = self.discriminator(real_img)
            disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)

        # plain 0 so both branches can accumulate with +=
        gen_loss = 0
        if self.few_shot_config['common']['type'] == 'cdc':
            # generator_s is evaluated on the same noises as self.generator
            fake_img_s = self.generator_s(self.batch_size, noises)
            cdc_loss = CrossDomainCorrespondence(
                fake_img,
                fake_img_s,
                _choice_num=self.few_shot_config['cdc']['feature_num'],
                _layer_fix_switch=self.few_shot_config['cdc']['layer_fix'])
            gen_loss += self.few_shot_config['cdc']['lambda'] * cdc_loss
            # --- PatchDiscriminator ---
            fake_disc_out, fake_feature_var = self.discriminator(
                fake_img, patch_switch=True, index=0)
            real_disc_out, real_feature_var = self.discriminator(
                real_img, patch_switch=True, index=0)
            disc_loss = disc_logistic_loss(real_disc_out, fake_disc_out)
            # NOTE(review): arguments are passed as (fake, real) here, while
            # the other disc_logistic_loss calls pass (real, fake) — confirm
            # this ordering is intended.
            disc_loss_patch = disc_logistic_loss(fake_feature_var,
                                                 real_feature_var)
            disc_loss += self.PD_switch_var * disc_loss_patch

        # in the 'cdc' branch fake_disc_out is the first output of the
        # patch_switch=True discriminator call
        gen_loss += gen_nonsaturating_loss(fake_disc_out)

        var_name_list = [
            'real_img', 'noises', 'fake_img', 'gen_loss', 'disc_loss',
            'fake_disc_out', 'real_disc_out', 'fake_img_test'
        ]
        var_list = [
            real_img, noises, fake_img, gen_loss, disc_loss, fake_disc_out,
            real_disc_out, fake_img_test
        ]

        if self.config['regularize_gen']:
            dlatents.need_grad = True
            mean_path_length = nn.Variable()
            pl_reg, path_mean, _ = gen_path_regularize(
                fake_img=fake_img,
                latents=dlatents,
                mean_path_length=mean_path_length)
            # Adding the assign with weight 0 makes it an ancestor of the
            # loss, so forwarding the loss also updates mean_path_length.
            path_mean_update = F.assign(mean_path_length, path_mean)
            path_mean_update.name = 'path_mean_update'
            pl_reg += 0 * path_mean_update
            gen_loss_reg = gen_loss + pl_reg
            var_name_list.append('gen_loss_reg')
            var_list.append(gen_loss_reg)

        if self.config['regularize_disc']:
            # gradients w.r.t. real_img are only enabled while the R1 term
            # is being built
            real_img.need_grad = True
            real_disc_out = self.discriminator(real_img)
            disc_loss_reg = disc_loss + self.config[
                'r1_coeff'] * 0.5 * disc_r1_loss(
                    real_disc_out, real_img) * self.config['disc_reg_step']
            real_img.need_grad = False
            var_name_list.append('disc_loss_reg')
            var_list.append(disc_loss_reg)

        Parameters = namedtuple('Parameters', var_name_list)
        self.parameters = Parameters(*var_list)
Example #11
0
 def ema_update(p_ema, p_train):
     """Return an assign node that moves ``p_ema`` towards ``p_train``.

     Forwarding the node sets
     ``p_ema <- ema_decay * p_ema + (1 - ema_decay) * p_train``.
     ``ema_decay`` is not an argument; it is looked up in the enclosing
     scope when this function is called.
     """
     blended = ema_decay * p_ema + (1. - ema_decay) * p_train
     return F.assign(p_ema, blended)