Example #1
 def VT(self, M):
     A = T.diag(M)  # M is in R^{n*n}, holding the output GP kernel
     # of all pairs of data in the data set
     B = A * A[:, None]
     C = T.sqrt(B)  # in R^{n*n}
     D = M / C  # this is lambda in the ReLU analytical formula
     E = T.clip(D, -1, 1)  # clipping E between -1 and 1 for numerical stability
     F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E ** 2)) * C
     G = (np.pi - T.arccos(E)) / (2 * np.pi)
     return F, G
 def alg1_VT_dep(self, M):  # here M plays the role of the previous little q; N x N, same value for every row
     A = T.diag(M)  # M is in R^{n*n}, holding the output GP kernel
     # of all pairs of data in the data set
     B = A * A[:, None]
     C = T.sqrt(B)  # in R^{n*n}
     D = M / C  # this is lambda in the ReLU analytical formula (c in the algorithm)
     E = T.clip(D, -1, 1)  # clipping E between -1 and 1 for numerical stability
     F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E ** 2)) * C
     G = (np.pi - T.arccos(E)) / (2 * np.pi)
     return F, G
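
The F, G pair above are the closed-form ReLU expectations: F[i, j] = E_w[relu(w.x_i) relu(w.x_j)] and G[i, j] = E_w[1{w.x_i > 0} 1{w.x_j > 0}] for w ~ N(0, I). Below is a minimal NumPy sketch with made-up data and a Monte Carlo check; X, n, d, P are illustrative names, not from this codebase.

import numpy as np

rng = np.random.default_rng(0)
n, d, P = 5, 3, 200000           # n data points, d features, P Monte Carlo samples
X = rng.normal(size=(n, d))

M = X @ X.T                      # input kernel, playing the role of M above
a = np.diag(M)
C = np.sqrt(a * a[:, None])      # ||x_i|| * ||x_j||
E = np.clip(M / C, -1, 1)        # cos(theta_ij)
F = (1 / (2 * np.pi)) * (E * (np.pi - np.arccos(E)) + np.sqrt(1 - E**2)) * C
G = (np.pi - np.arccos(E)) / (2 * np.pi)

# Monte Carlo estimates of E_w[relu(w.x) relu(w.y)] and E_w[1{w.x>0} 1{w.y>0}]
W = rng.normal(size=(d, P))
H = X @ W
F_mc = np.maximum(H, 0) @ np.maximum(H, 0).T / P
G_mc = (H > 0).astype(float) @ (H > 0).astype(float).T / P
print(np.abs(F - F_mc).max(), np.abs(G - G_mc).max())  # both errors should be O(1/sqrt(P))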
Example #3
def RNTK_relu(x, RNTK_old, GP_old, param, output):
    sw = param["sigmaw"]
    su = param["sigmau"]
    sb = param["sigmab"]
    sv = param["sigmav"]

    a = T.diag(GP_old)  # GP_old is in R^{n*n}, holding the output GP kernel
    # of all pairs of data in the data set
    B = a * a[:, None]
    C = T.sqrt(B)  # in R^{n*n}
    D = GP_old / C  # this is lambda in the ReLU analytical formula
    # clipping E between -1 and 1 for numerical stability
    E = T.clip(D, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E**2)) * C
    G = (np.pi - T.arccos(E)) / (2 * np.pi)
    if output:
        GP_new = sv**2 * F
        RNTK_new = sv**2.0 * RNTK_old * G + GP_new
    else:
        X = x * x[:, None]
        # `m` is assumed to be defined in the enclosing scope (the input dimension)
        GP_new = sw**2 * F + (su**2 / m) * X + sb**2
        RNTK_new = sw**2.0 * RNTK_old * G + GP_new
    return RNTK_new, GP_new
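
RNTK_relu computes a single step of the GP/RNTK recursion. Below is a minimal NumPy sketch of how such a step would typically be chained over a sequence; the driver loop, the zero-initial-hidden-state convention, and the names relu_phi and data are illustrative assumptions, not code from this repo.

import numpy as np

def relu_phi(K):
    # NumPy port of the F, G pair computed inside RNTK_relu
    a = np.diag(K)
    C = np.sqrt(a * a[:, None])
    E = np.clip(K / C, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - np.arccos(E)) + np.sqrt(1 - E**2)) * C
    G = (np.pi - np.arccos(E)) / (2 * np.pi)
    return F, G

n, m, n_steps = 4, 3, 5                      # n sequences, m input features, n_steps time steps
rng = np.random.default_rng(0)
data = rng.normal(size=(n, m, n_steps))
sw = su = sv = 1.0
sb = 0.1

# zero initial hidden state: the first-step kernels only see the inputs
x = data[:, :, 0]
GP = (su**2 / m) * (x @ x.T) + sb**2
RNTK = GP.copy()

for t in range(1, n_steps):                  # recurrent steps (output=False branch)
    F, G = relu_phi(GP)
    x = data[:, :, t]
    GP = sw**2 * F + (su**2 / m) * (x @ x.T) + sb**2
    RNTK = sw**2 * RNTK * G + GP

F, G = relu_phi(GP)                          # readout (output=True branch)
GP_out = sv**2 * F
RNTK_out = sv**2 * RNTK * G + GP_out
print(RNTK_out.shape)                        # (n, n) kernel over the n sequences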
Example #4
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        n=1,
        clip_ratio=0.2,
        entcoeff=0.01,
        target_kl=0.01,
        train_pi_iters=4,
        train_v_iters=4,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.extras = {"logprob": ()}
        self.entcoeff = entcoeff

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        ret_ph = T.Placeholder((batch_size, ), "float32")
        adv_ph = T.Placeholder((batch_size, ), "float32")
        act_ph = T.Placeholder((batch_size, num_actions), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if continuous:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
        with symjax.Scope("old_actor"):
            old_logits = actor(state_ph)
            if continuous:
                old_logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )

        if not continuous:
            pi = Categorical(logits=logits)
        else:
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

        actions = T.clip(pi.sample(), -2, 2)  # sampled actions, clipped to a fixed range

        actor_params = symjax.get_variables(scope="/actor/")
        old_actor_params = symjax.get_variables(scope="/old_actor/")

        self.update_target = symjax.function(
            updates={o: a
                     for o, a in zip(old_actor_params, actor_params)})

        # PPO objectives
        # pi(a|s) / pi_old(a|s)
        pi_log_prob = pi.log_prob(act_ph)
        # substitutions for the old policy; logstd only exists in the continuous case
        givens = {logits: old_logits}
        if continuous:
            givens[logstd] = old_logstd
        old_pi_log_prob = pi_log_prob.clone(givens)

        ratio = T.exp(pi_log_prob - old_pi_log_prob)
        surr1 = ratio * adv_ph
        surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

        pi_loss = -T.minimum(surr1, surr2).mean()
        # ent_loss = pi.entropy().mean() * self.entcoeff

        with symjax.Scope("critic"):
            v = critic(state_ph)
        # critic loss
        v_loss = ((ret_ph - v)**2).mean()

        # Info (useful to watch during learning)

        # a sample estimate for KL
        approx_kl = (old_pi_log_prob - pi_log_prob).mean()
        # a sample estimate for entropy
        # approx_ent = -logprob_given_actions.mean()
        # clipped = T.logical_or(
        #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
        # )
        # clipfrac = clipped.astype("float32").mean()

        with symjax.Scope("update_pi"):
            print(len(actor_params), "actor parameters")
            nn.optimizers.Adam(
                pi_loss,
                self.lr,
                params=actor_params,
            )
        with symjax.Scope("update_v"):
            critic_params = symjax.get_variables(scope="/critic/")
            print(len(critic_params), "critic parameters")

            nn.optimizers.Adam(
                v_loss,
                self.lr,
                params=critic_params,
            )

        self.get_params = symjax.function(outputs=critic_params)

        self.learn_pi = symjax.function(
            state_ph,
            act_ph,
            adv_ph,
            outputs=[pi_loss, approx_kl],
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            ret_ph,
            outputs=v_loss,
            updates=symjax.get_updates(scope="/update_v/"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_v = v.clone({state_ph: single_state})
        single_sample = actions.clone({state_ph: single_state})

        self._act = symjax.function(single_state, outputs=single_sample)
        self._get_v = symjax.function(single_state, outputs=single_v)

        single_action = T.Placeholder((1, num_actions), "float32")
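
pi_loss above is the standard PPO clipped-surrogate objective. Below is a minimal NumPy sketch with made-up log-probabilities and advantages, just to make the ratio/clip/minimum structure concrete.

import numpy as np

clip_ratio = 0.2
logp = np.array([-1.0, -0.7, -2.3])      # log pi(a|s) under the current policy
logp_old = np.array([-1.1, -1.0, -2.0])  # log pi_old(a|s) under the old policy
adv = np.array([0.5, 1.0, -0.8])         # advantage estimates

ratio = np.exp(logp - logp_old)
surr1 = ratio * adv
surr2 = np.clip(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
pi_loss = -np.minimum(surr1, surr2).mean()   # pessimistic (clipped) objective
print(pi_loss)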
Example #5
    def __init__(
        self,
        state_dim,
        action_dim,
        lr,
        gamma,
        K_epochs,
        eps_clip,
        actor,
        critic,
        batch_size,
        continuous=True,
    ):
        self.lr = lr
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        state = T.Placeholder((batch_size, ) + state_dim, "float32")

        reward = T.Placeholder((batch_size, ), "float32")
        old_action_logprobs = T.Placeholder((batch_size, ), "float32")

        logits = actor(state)

        if not continuous:
            given_action = T.Placeholder((batch_size, ), "int32")
            dist = Categorical(logits=logits)
        else:
            mean = T.tanh(logits[:, :logits.shape[1] // 2])
            std = T.exp(logits[:, logits.shape[1] // 2:])
            given_action = T.Placeholder((batch_size, action_dim), "float32")
            dist = MultivariateNormal(mean=mean, diag_std=std)

        sample = dist.sample()
        sample_logprobs = dist.log_prob(sample)

        self._act = symjax.function(state, outputs=[sample, sample_logprobs])

        given_action_logprobs = dist.log_prob(given_action)

        # Finding the ratio (pi_theta / pi_theta__old) for the given actions:
        ratios = T.exp(given_action_logprobs - old_action_logprobs)
        ratios = T.clip(ratios, None, 1 + self.eps_clip)

        state_value = critic(state)
        advantages = reward - T.stop_gradient(state_value)

        loss = (-T.mean(ratios * advantages) + 0.5 * T.mean(
            (state_value - reward)**2) - 0.0 * dist.entropy().mean())

        print(loss)
        nn.optimizers.Adam(loss, self.lr)

        self.learn = symjax.function(
            state,
            given_action,
            reward,
            old_action_logprobs,
            outputs=T.mean(loss),
            updates=symjax.get_updates(),
        )
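
advantages above subtracts T.stop_gradient(state_value) so the surrogate term cannot push gradients into the critic; only the explicit squared value error trains it. Below is a minimal JAX sketch with toy scalars (the surrogate function, v_param, and the constants are hypothetical) showing the difference.

import jax

def surrogate(v_param, detach, reward=1.0, ratio=1.2):
    state_value = v_param * 3.0  # stand-in for critic(state)
    adv = reward - (jax.lax.stop_gradient(state_value) if detach else state_value)
    return ratio * adv

g_detached = jax.grad(surrogate)(0.5, True)   # 0.0: no gradient reaches the critic
g_attached = jax.grad(surrogate)(0.5, False)  # -3.6: the surrogate would also train the critic
print(g_detached, g_attached)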
Example #6
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        K_epochs=80,
        eps_clip=0.2,
        gamma=0.99,
        entropy_beta=0.01,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape,
                               "float32",
                               name="states")
        actions = T.Placeholder((batch_size, ) + actions_shape,
                                "float32",
                                name="actions")
        rewards = T.Placeholder((batch_size, ),
                                "float32",
                                name="discounted_rewards")
        advantages = T.Placeholder((batch_size, ),
                                   "float32",
                                   name="advantages")

        self.target_actor = actor(states, distribution="gaussian")
        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        # Finding the ratio (pi_theta / pi_theta__old) and
        # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
        with symjax.Scope("policy_loss"):
            ratios = T.exp(
                self.actor.actions.log_prob(actions) -
                self.target_actor.actions.log_prob(actions))
            ratios = T.clip(ratios, 0, 10)
            clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                    1 + self.eps_clip)

            surr1 = advantages * ratios
            surr2 = advantages * clipped_ratios

            actor_loss = -(T.minimum(surr1, surr2)).mean()

        with symjax.Scope("monitor"):
            clipfrac = (((ratios > (1 + self.eps_clip)) |
                         (ratios <
                          (1 - self.eps_clip))).astype("float32").mean())
            approx_kl = (self.target_actor.actions.log_prob(actions) -
                         self.actor.actions.log_prob(actions)).mean()

        with symjax.Scope("critic_loss"):
            critic_loss = T.mean((rewards - self.critic.q_values)**2)

        with symjax.Scope("entropy"):
            entropy = self.actor.actions.entropy().mean()

        loss = actor_loss + critic_loss  # - entropy_beta * entropy

        with symjax.Scope("optimizer"):
            nn.optimizers.Adam(
                loss,
                lr,
                params=self.actor.params(True) + self.critic.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            rewards,
            advantages,
            outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
            updates=symjax.get_updates(scope="*optimizer"),
        )

        # initialize target as current
        self.update_target(1)
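
The "monitor" scope tracks two standard PPO diagnostics: clipfrac, the fraction of samples whose ratio left the clipping interval, and a sample estimate of the KL divergence between the old and new policies. A minimal NumPy sketch with made-up ratios:

import numpy as np

eps_clip = 0.2
ratios = np.array([0.70, 0.95, 1.10, 1.35, 1.05])
logp_old = np.array([-1.2, -0.9, -1.4, -0.6, -2.0])
logp_new = logp_old + np.log(ratios)             # consistent with ratio = pi / pi_old

clipfrac = ((ratios > 1 + eps_clip) | (ratios < 1 - eps_clip)).astype("float32").mean()
approx_kl = (logp_old - logp_new).mean()         # same estimator as approx_kl above
print(clipfrac, approx_kl)                       # clipfrac = 0.4: two of five samples clipped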