Code example #1
    def __init__(self, *args, name=None, **kwargs):

        if name is None:
            name = self.__NAME__
        with symjax.Scope(name):
            self.create_updates(*args, **kwargs)
            self._scope_name = symjax.current_graph().scope_name
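This constructor pattern (also used in examples #3, #6, and #17 below) wraps object creation in symjax.Scope so that every variable the object builds is registered under a common name prefix. The following standalone sketch is not taken from the SymJAX sources; it assumes the usual import symjax.tensor as T convention and only uses calls that appear elsewhere on this page:

import symjax
import symjax.tensor as T

symjax.current_graph().reset()

# variables created inside a Scope are registered under its name prefix
with symjax.Scope("my_layer"):
    w = T.Variable(1.0, dtype="float32", name="w", trainable=True)
    scope_name = symjax.current_graph().scope_name

# retrieve everything created under the "my_layer" prefix, mirroring the
# scope-based lookups in the examples below
params = symjax.get_variables(trainable=None, scope=scope_name)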
Code example #2
File: agents.py  Project: SymJAX/SymJAX
    def update_target(self, tau=None):

        if not hasattr(self, "_update_target"):
            with symjax.Scope("update_target"):
                targets = []
                currents = []
                if hasattr(self, "target_actor"):
                    targets += self.target_actor.params(True)
                    currents += self.actor.params(True)
                if hasattr(self, "target_critic"):
                    targets += self.target_critic.params(True)
                    currents += self.critic.params(True)

                _tau = T.Placeholder((), "float32")
                updates = {
                    t: t * (1 - _tau) + a * _tau
                    for t, a in zip(targets, currents)
                }
            self._update_target = symjax.function(_tau, updates=updates)

        if tau is None:
            if not hasattr(self, "tau"):
                raise RuntimeError("tau must be specified")
            tau = self.tau
        self._update_target(tau)
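update_target builds its update graph only once: a scalar placeholder _tau and one update rule per parameter pair, t <- (1 - tau) * t + tau * a. Calling it with a small tau therefore performs a soft (Polyak) update of the target networks toward the current ones, while update_target(1), as used in examples #13 and #15, copies the current parameters outright.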
Code example #3
    def __init__(self, *args, name=None, **kwargs):

        if name is None:
            name = self.__NAME__

        with symjax.Scope(name):
            output = self.forward(*args, **kwargs)
            super().__init__(output, 0, _jax_function=jax.numpy.add)
Code example #4
File: agents.py  Project: SymJAX/SymJAX
    def __init__(self, states, actions=None):
        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:],
                              "float32",
                              name="critic_state")
        if actions is not None:
            self.action_shape = actions.shape[1:]
            action = T.Placeholder((1, ) + actions.shape[1:],
                                   "float32",
                                   name="critic_action")
            action_shape = action.shape[1:]

            with symjax.Scope("critic"):
                q_values = self.create_network(states, actions)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state, actions: action})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            inputs = [states, actions]
            input = [state, action]
            self.actions = actions
            self.action = action

        else:
            with symjax.Scope("critic"):
                q_values = self.create_network(states)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            inputs = [states]
            input = [state]

        self.q_values = q_values
        self.state = state
        self.states = states

        self._get_q_values = symjax.function(*inputs, outputs=q_values)
        self._get_q_value = symjax.function(*input, outputs=q_value[0])
Code example #5
    def build_net(self, Q):
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder(
            [
                self.batch_size,
            ],
            "float32",
            name="r",
        )  # input reward
        action = T.Placeholder(
            [
                self.batch_size,
            ],
            "int32",
            name="a",
        )  # input Action

        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        a_indices = T.stack([T.range(self.batch_size), action], axis=1)
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
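Here the Bellman target reward + reward_decay * max_a q_next is wrapped in stop_gradient, so the mean-squared loss only backpropagates into the network built under the eval_net scope; the copy built under the second scope plays the role of the fixed target network in standard DQN.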
Code example #6
File: layers.py  Project: SymJAX/SymJAX
    def __new__(cls, *args, name=None, **kwargs):

        if name is None:
            name = cls.__NAME__

        with symjax.Scope(name):

            output = cls.__init__(cls, *args, **kwargs)

        return output
Code example #7
def test_clone_0():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * w * u
    c = value.clone({w: u})
    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)

    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
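The assertion follows from the clone semantics: value computes 2 * w * u with w = 1, so f(1) = 2 and f(2) = 4, while c replaces w by u and therefore computes 2 * u * u, giving g(1) = 2 and g(2) = 8.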
Code example #8
File: agents.py  Project: SymJAX/SymJAX
    def __call__(self, action, episode):

        with symjax.Scope("OUProcess"):
            self.episode = T.Variable(1,
                                      "float32",
                                      name="episode",
                                      trainable=False)

        self.noise_scale = self.initial_noise_scale * self.noise_decay**episode

        x = (self.process + self.theta * (self.mean - self.process) * self.dt +
             self.std_dev * np.sqrt(self.dt) *
             np.random.normal(size=action.shape))
        # Store x into process
        # Makes next noise dependent on current one
        self.process = x

        return action + self.noise_scale * self.process
Code example #9
    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        lr=1e-3,
        gamma=0.99,
    ):
        self.actor = actor
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes

        states = T.Placeholder((self.batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size, ) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size, ), "float32")

        self.actor = actor(states, distribution="gaussian")

        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * discounted_rewards).sum() / n_episodes

        with symjax.Scope("REINFORCE_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            discounted_rewards,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/REINFORCE_optimizer"),
        )
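The loss is the standard REINFORCE objective: the negative log-probability of the taken actions weighted by the discounted returns, summed and divided by the number of episodes. The Adam updates are then gathered by scope name ("*/REINFORCE_optimizer") so that only this optimizer's updates are compiled into _train.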
Code example #10
File: agents.py  Project: SymJAX/SymJAX
    def __init__(self, states, actions_distribution=None, name="actor"):

        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:], "float32")
        self.actions_distribution = actions_distribution

        with symjax.Scope(name):
            if actions_distribution == symjax.probabilities.Normal:

                means, covs = self.create_network(states)

                actions = actions_distribution(means, cov=covs)
                samples = actions.sample()
                samples_log_prob = actions.log_prob(samples)

                action = symjax.probabilities.MultivariateNormal(
                    means.clone({states: state}),
                    cov=covs.clone({states: state}),
                )
                sample = action.sample()
                sample_log_prob = action.log_prob(sample)

                self._get_actions = symjax.function(
                    states, outputs=[samples, samples_log_prob])
                self._get_action = symjax.function(
                    state,
                    outputs=[sample[0], sample_log_prob[0]],
                )
            elif actions_distribution is None:
                actions = self.create_network(states)
                action = actions.clone({states: state})

                self._get_actions = symjax.function(states, outputs=actions)
                self._get_action = symjax.function(state, outputs=action[0])

            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name)
        self.actions = actions
        self.state = state
        self.action = action
Code example #11
    def __init__(
        self,
        env_fn,
        actor,
        critic,
        gamma=0.99,
        tau=0.01,
        lr=1e-3,
        batch_size=32,
        epsilon=0.1,
        epsilon_decay=1 / 1000,
        min_epsilon=0.01,
        reward=None,
    ):

        # comment out this line if you don't want to record a video of the agent
        # if save_folder is not None:
        # test_env = gym.wrappers.Monitor(test_env)

        # get size of state space and action space
        env = env_fn()  # assumes env_fn constructs the gym environment
        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_max = env.action_space.high[0]
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.gamma = gamma
        self.continuous = continuous
        self.observ_min = np.clip(env.observation_space.low, -20, 20)
        self.observ_max = np.clip(env.observation_space.high, -20, 20)
        self.env = env
        self.reward = reward

        # state
        state = T.Placeholder((batch_size, num_states), "float32")
        gradients = T.Placeholder((batch_size, num_actions), "float32")
        action = T.Placeholder((batch_size, num_actions), "float32")
        target = T.Placeholder((batch_size, 1), "float32")

        with symjax.Scope("actor_critic"):

            scaled_out = action_max * actor(state)
            Q = critic(state, action)

        a_loss = -T.sum(gradients * scaled_out)
        q_loss = T.mean((Q - target)**2)

        nn.optimizers.Adam(a_loss + q_loss, lr)

        self.update = symjax.function(
            state,
            action,
            target,
            gradients,
            outputs=[a_loss, q_loss],
            updates=symjax.get_updates(),
        )
        g = symjax.gradients(T.mean(Q), [action])[0]
        self.get_gradients = symjax.function(state, action, outputs=g)

        # also create the target variants
        with symjax.Scope("actor_critic_target"):
            scaled_out_target = action_max * actor(state)
            Q_target = critic(state, action)

        self.actor_predict = symjax.function(state, outputs=scaled_out)
        self.actor_predict_target = symjax.function(state,
                                                    outputs=scaled_out_target)
        self.critic_predict = symjax.function(state, action, outputs=Q)
        self.critic_predict_target = symjax.function(state,
                                                     action,
                                                     outputs=Q_target)

        t_params = symjax.get_variables(scope="/actor_critic_target/*")
        params = symjax.get_variables(scope="/actor_critic/*")
        replacement = {
            t: tau * e + (1 - tau) * t
            for t, e in zip(t_params, params)
        }
        self.update_target = symjax.function(updates=replacement)

        single_state = T.Placeholder((1, num_states), "float32")
        if not continuous:
            # assumes the greedy action is taken from scaled_out in the discrete case
            scaled_out = scaled_out.argmax(-1)

        self.act = symjax.function(single_state,
                                   outputs=scaled_out.clone(
                                       {state: single_state})[0])
Code example #12
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        train_pi_iters=80,
        train_v_iters=80,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.extras = {}

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        rew_ph = T.Placeholder((batch_size, ), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if not continuous:
                pi = Categorical(logits=logits)
            else:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
                pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

            actions = pi.sample()  # pi
            actions_log_prob = pi.log_prob(actions)  # logp

        with symjax.Scope("critic"):
            critic_value = critic(state_ph)

        # AC objectives
        diff = rew_ph - critic_value
        actor_loss = -(actions_log_prob * diff).mean()
        critic_loss = nn.losses.squared_differences(rew_ph,
                                                    critic_value).mean()

        with symjax.Scope("update_pi"):
            nn.optimizers.Adam(
                actor_loss,
                self.lr,
                params=symjax.get_variables(scope="/actor/"),
            )
        with symjax.Scope("update_v"):
            nn.optimizers.Adam(
                critic_loss,
                self.lr,
                params=symjax.get_variables(scope="/critic/"),
            )

        self.learn_pi = symjax.function(
            state_ph,
            rew_ph,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            rew_ph,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="/update_v/*"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_action = actions.clone({state_ph: single_state})[0]
        single_v = critic_value.clone({state_ph: single_state})

        self._act = symjax.function(
            single_state,
            outputs=[single_action, single_v],
        )
Code example #13
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        K_epochs=80,
        eps_clip=0.2,
        gamma=0.99,
        entropy_beta=0.01,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape,
                               "float32",
                               name="states")
        actions = T.Placeholder((batch_size, ) + actions_shape,
                                "float32",
                                name="actions")
        rewards = T.Placeholder((batch_size, ),
                                "float32",
                                name="discounted_rewards")
        advantages = T.Placeholder((batch_size, ),
                                   "float32",
                                   name="advantages")

        self.target_actor = actor(states, distribution="gaussian")
        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        # Finding the ratio (pi_theta / pi_theta__old) and
        # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
        with symjax.Scope("policy_loss"):
            ratios = T.exp(
                self.actor.actions.log_prob(actions) -
                self.target_actor.actions.log_prob(actions))
            ratios = T.clip(ratios, 0, 10)
            clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                    1 + self.eps_clip)

            surr1 = advantages * ratios
            surr2 = advantages * clipped_ratios

            actor_loss = -(T.minimum(surr1, surr2)).mean()

        with symjax.Scope("monitor"):
            clipfrac = (((ratios > (1 + self.eps_clip)) |
                         (ratios <
                          (1 - self.eps_clip))).astype("float32").mean())
            approx_kl = (self.target_actor.actions.log_prob(actions) -
                         self.actor.actions.log_prob(actions)).mean()

        with symjax.Scope("critic_loss"):
            critic_loss = T.mean((rewards - self.critic.q_values)**2)

        with symjax.Scope("entropy"):
            entropy = self.actor.actions.entropy().mean()

        loss = actor_loss + critic_loss  # - entropy_beta * entropy

        with symjax.Scope("optimizer"):
            nn.optimizers.Adam(
                loss,
                lr,
                params=self.actor.params(True) + self.critic.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            rewards,
            advantages,
            outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
            updates=symjax.get_updates(scope="*optimizer"),
        )

        # initialize target as current
        self.update_target(1)
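This is the clipped surrogate objective of PPO from the paper linked in the comment: ratios compares the current policy to the frozen target (old) policy, the actor loss takes the minimum of the clipped and unclipped advantage-weighted ratios, and the monitor scope tracks the clipping fraction and an approximate KL divergence between the two policies.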
Code example #14
    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        train_v_iters=10,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes
        self.train_v_iters = train_v_iters

        states = T.Placeholder((self.batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size, ) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size, ), "float32")
        advantages = T.Placeholder((self.batch_size, ), "float32")

        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * advantages).sum() / n_episodes
        critic_loss = 0.5 * (
            (discounted_rewards - self.critic.q_values)**2).mean()

        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(
                critic_loss,
                lr,
                params=self.critic.params(True),
            )

        # create the update function
        self._train_actor = symjax.function(
            states,
            actions,
            advantages,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )
        # create the update function
        self._train_critic = symjax.function(
            states,
            discounted_rewards,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )
Code example #15
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        tau=0.01,
    ):

        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.batch_size = batch_size

        states = T.Placeholder((batch_size, ) + state_shape, "float32")
        actions = T.Placeholder((batch_size, ) + actions_shape, "float32")
        self.critic = critic(states, actions)
        self.target_critic = critic(states, actions)

        # create critic loss
        targets = T.Placeholder(self.critic.q_values.shape, "float32")
        critic_loss = ((self.critic.q_values - targets)**2).mean()

        # create optimizer
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(critic_loss,
                               lr,
                               params=self.critic.params(True))

        # create the update function
        self._train_critic = symjax.function(
            states,
            actions,
            targets,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )

        # now create utility function to get the gradients
        grad = symjax.gradients(self.critic.q_values.sum(), actions)
        self._get_critic_gradients = symjax.function(states,
                                                     actions,
                                                     outputs=grad)

        # create actor loss
        self.actor = actor(states)
        self.target_actor = actor(states)
        gradients = T.Placeholder(actions.shape, "float32")
        actor_loss = -(self.actor.actions * gradients).mean()

        # create optimizer
        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

        # create the update function
        self._train_actor = symjax.function(
            states,
            gradients,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )

        # initialize both networks as the same
        self.update_target(1)
Code example #16
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        n=1,
        clip_ratio=0.2,
        entcoeff=0.01,
        target_kl=0.01,
        train_pi_iters=4,
        train_v_iters=4,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.extras = {"logprob": ()}
        self.entcoeff = entcoeff

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        ret_ph = T.Placeholder((batch_size, ), "float32")
        adv_ph = T.Placeholder((batch_size, ), "float32")
        act_ph = T.Placeholder((batch_size, num_actions), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if continuous:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
        with symjax.Scope("old_actor"):
            old_logits = actor(state_ph)
            if continuous:
                old_logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )

        if not continuous:
            pi = Categorical(logits=logits)
        else:
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

        actions = T.clip(pi.sample(), -2, 2)  # pi

        actor_params = symjax.get_variables(scope="/actor/")
        old_actor_params = symjax.get_variables(scope="/old_actor/")

        self.update_target = symjax.function(
            updates={o: a
                     for o, a in zip(old_actor_params, actor_params)})

        # PPO objectives
        # pi(a|s) / pi_old(a|s)
        pi_log_prob = pi.log_prob(act_ph)
        # logstd only exists in the continuous case
        replacements = {logits: old_logits}
        if continuous:
            replacements[logstd] = old_logstd
        old_pi_log_prob = pi_log_prob.clone(replacements)

        ratio = T.exp(pi_log_prob - old_pi_log_prob)
        surr1 = ratio * adv_ph
        surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

        pi_loss = -T.minimum(surr1, surr2).mean()
        # ent_loss = pi.entropy().mean() * self.entcoeff

        with symjax.Scope("critic"):
            v = critic(state_ph)
        # critic loss
        v_loss = ((ret_ph - v)**2).mean()

        # Info (useful to watch during learning)

        # a sample estimate for KL
        approx_kl = (old_pi_log_prob - pi_log_prob).mean()
        # a sample estimate for entropy
        # approx_ent = -logprob_given_actions.mean()
        # clipped = T.logical_or(
        #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
        # )
        # clipfrac = clipped.astype("float32").mean()

        with symjax.Scope("update_pi"):
            print(len(actor_params), "actor parameters")
            nn.optimizers.Adam(
                pi_loss,
                self.lr,
                params=actor_params,
            )
        with symjax.Scope("update_v"):
            critic_params = symjax.get_variables(scope="/critic/")
            print(len(critic_params), "critic parameters")

            nn.optimizers.Adam(
                v_loss,
                self.lr,
                params=critic_params,
            )

        self.get_params = symjax.function(outputs=critic_params)

        self.learn_pi = symjax.function(
            state_ph,
            act_ph,
            adv_ph,
            outputs=[pi_loss, approx_kl],
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            ret_ph,
            outputs=v_loss,
            updates=symjax.get_updates(scope="/update_v/"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_v = v.clone({state_ph: single_state})
        single_sample = actions.clone({state_ph: single_state})

        self._act = symjax.function(single_state, outputs=single_sample)
        self._get_v = symjax.function(single_state, outputs=single_v)

        single_action = T.Placeholder((1, num_actions), "float32")
Code example #17
    def __init__(self, *args, name=None, **kwargs):

        if name is None:
            name = self.__NAME__
        with symjax.Scope(name):
            self.create_updates(*args, **kwargs)