Example #1
    def __init__(
        self,
        n_actions,
        n_states,
        Q,
        learning_rate=0.01,
        reward_decay=0.8,
        e_greedy=0.9,
        replace_target_iter=30,
        memory_size=500,
        batch_size=32,
        e_greedy_increment=0.001,
        save_steps=-1,
        output_graph=False,
        record_history=True,
        observation_interval=0.01,
    ):
        self.n_actions = n_actions
        self.n_states = n_states
        self.lr = learning_rate
        self.reward_decay = reward_decay
        self.epsilon_max = e_greedy
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.epsilon_increment = e_greedy_increment
        self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max
        self.record_history = record_history
        self.observation_interval = observation_interval

        # total number of learning steps
        self.learn_step_counter = 0
        self.action_step_counter = 0

        # initialize zero memory [s, a, r, s_]
        self.memory = np.zeros((self.memory_size, n_states * 2 + 2))
        self.memory_counter = 0
        # build the two networks [target_net, eval_net]
        self.build_net(Q)

        # save data
        self.save_steps = save_steps
        self.steps = 0

        t_params = symjax.get_variables(scope="target_net")
        e_params = symjax.get_variables(scope="eval_net")

        replacement = {t: e for t, e in zip(t_params, e_params)}
        self.replace = symjax.function(updates=replacement)

        self.cost_history = []
        self._tmp_cost_history = []
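The replacement function compiled above is presumably triggered on a schedule during learning; a minimal sketch of that step (the learn method name is an assumption, only the counters and self.replace come from the snippet above):

    def learn(self):
        # sketch: sync target_net with eval_net every `replace_target_iter`
        # learning steps, using the update function compiled in __init__
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.replace()
        self.learn_step_counter += 1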
Example #2
def test_accessing_variables():
    sj.current_graph().reset()
    w1 = symjax.tensor.Variable(1.0, trainable=True)
    w2 = symjax.tensor.Variable(1.0, trainable=True)
    w3 = symjax.tensor.Variable(1.0, trainable=False)

    v = symjax.get_variables("*", trainable=True)
    assert w1 in v and w2 in v and w3 not in v

    v = symjax.get_variables("*", trainable=False)
    assert w1 not in v and w2 not in v and w3 in v

    v = symjax.get_variables("*test")
    assert len(v) == 0
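The same retrieval also works per scope; a small sketch, assuming a hypothetical scope named "net":

import symjax
import symjax.tensor as T

symjax.current_graph().reset()
with symjax.Scope("net"):
    w = T.Variable(1.0, name="w", trainable=True)
    mean = T.Variable(0.0, name="mean", trainable=False)

# every variable created under /net/, whatever its trainable flag
print(symjax.get_variables(scope="/net/*", trainable=None))
# only the trainable ones
print(symjax.get_variables(scope="/net/*", trainable=True))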
Example #3
    def train(self, buffer, *args, **kwargs):

        indices = list(range(buffer.length))
        states, actions, rewards, advantages = buffer.sample(
            indices,
            ["state", "action", "reward-to-go", "advantage"],
        )

        # Optimize policy for K epochs:
        advantages -= advantages.mean()
        advantages /= advantages.std()

        for _ in range(self.K_epochs):

            for s, a, r, adv in symjax.data.utils.batchify(
                    states,
                    actions,
                    rewards,
                    advantages,
                    batch_size=self.batch_size,
            ):

                loss = self._train(s, a, r, adv)

        print([v.value for v in symjax.get_variables(name="logsigma")])
        buffer.reset_data()
        self.update_target(1)
        return loss
Example #4
    def create_updates(
        self,
        grads_or_loss,
        learning_rate,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        params=None,
    ):

        if params is None:
            params = symjax.get_variables(trainable=True)

        grads = self._get_grads(grads_or_loss, params)

        # get the learning rate
        if callable(learning_rate):
            learning_rate = learning_rate()

        updates = dict()
        for param, grad in zip(params, grads):
            m = symjax.nn.schedules.ExponentialMovingAverage(grad, beta1)[0]
            v = symjax.nn.schedules.ExponentialMovingAverage(grad**2, beta2)[0]
            update = m / (tensor.sqrt(v) + epsilon)
            updates[param] = param - learning_rate * update

        self.add_updates(updates)
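Once an optimizer has registered its updates this way, they are typically gathered with symjax.get_updates() and compiled into a single training step; a hedged, self-contained sketch (the toy loss, shapes, and learning rate are illustrative only):

import numpy as np
import symjax
import symjax.tensor as T
import symjax.nn as nn

symjax.current_graph().reset()
x = T.Placeholder((4, 3), "float32")
w = T.Variable(np.ones(3, dtype="float32"), trainable=True)
loss = T.mean((x * w - 1.0) ** 2)

# the optimizer registers its parameter updates on the current graph
nn.optimizers.Adam(loss, 0.01, params=symjax.get_variables(trainable=True))
train = symjax.function(x, outputs=loss, updates=symjax.get_updates())

print(train(np.random.randn(4, 3).astype("float32")))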
Example #5
    def _get_variables(self, loss):

        params = symjax.get_variables(trainable=True)

        params = [
            p for p in params if symjax.current_graph().is_connected(p, loss)
        ]
        return params
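A hedged sketch of what this connectivity filter does in isolation (the variable names are illustrative): a variable that does not contribute to the loss is dropped from the returned list.

import symjax
import symjax.tensor as T

symjax.current_graph().reset()
used = T.Variable(1.0, name="used", trainable=True)
unused = T.Variable(1.0, name="unused", trainable=True)
loss = (used - 2.0) ** 2

params = [
    p for p in symjax.get_variables(trainable=True)
    if symjax.current_graph().is_connected(p, loss)
]
print(params)  # expected: contains `used`, while `unused` is filtered out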
Example #6
File: agents.py  Project: SymJAX/SymJAX
    def __init__(self, states, actions=None):
        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:],
                              "float32",
                              name="critic_state")
        if actions is not None:
            self.action_shape = actions.shape[1:]
            action = T.Placeholder((1, ) + actions.shape[1:],
                                   "float32",
                                   name="critic_action")
            action_shape = action.shape[1:]

            with symjax.Scope("critic"):
                q_values = self.create_network(states, actions)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state, actions: action})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            inputs = [states, actions]
            input = [state, action]
            self.actions = actions
            self.action = action

        else:
            with symjax.Scope("critic"):
                q_values = self.create_network(states)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name)

            inputs = [states]
            input = [state]

        self.q_values = q_values
        self.state = state
        self.states = states

        self._get_q_values = symjax.function(*inputs, outputs=q_values)
        self._get_q_value = symjax.function(*input, outputs=q_value[0])
Example #7
File: agents.py  Project: SymJAX/SymJAX
    def __init__(self, states, actions_distribution=None, name="actor"):

        self.state_shape = states.shape[1:]
        state = T.Placeholder((1, ) + states.shape[1:], "float32")
        self.actions_distribution = actions_distribution

        with symjax.Scope(name):
            if actions_distribution == symjax.probabilities.Normal:

                means, covs = self.create_network(states)

                actions = actions_distribution(means, cov=covs)
                samples = actions.sample()
                samples_log_prob = actions.log_prob(samples)

                action = symjax.probabilities.MultivariateNormal(
                    means.clone({states: state}),
                    cov=covs.clone({states: state}),
                )
                sample = action.sample()
                sample_log_prob = action.log_prob(sample)

                self._get_actions = symjax.function(
                    states, outputs=[samples, samples_log_prob])
                self._get_action = symjax.function(
                    state,
                    outputs=[sample[0], sample_log_prob[0]],
                )
            elif actions_distribution is None:
                actions = self.create_network(states)
                action = actions.clone({states: state})

                self._get_actions = symjax.function(states, outputs=actions)
                self._get_action = symjax.function(state, outputs=action[0])

            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name)
        self.actions = actions
        self.state = state
        self.action = action
Example #8
def test_bn():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input,
                         deterministic,
                         outputs=bn,
                         updates=sj.get_updates())
    avmean = symjax.get_variables(trainable=None)[-3]
    print(avmean)
    get_stats = sj.function(input, outputs=avmean[0])

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = [np.zeros(DIM)]
    actual_means = [np.zeros(DIM)]

    for i in range(10):
        batch = data[BATCH_SIZE * i:BATCH_SIZE * (i + 1)]

        output = update(batch, 0)
        assert np.allclose(
            output,
            (batch - batch.mean(0)) / np.sqrt(0.001 + batch.var(0)),
            1e-4,
        )

        actual_means.append(get_stats(batch))
        true_means.append(0.99 * true_means[-1] + 0.01 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means)
Example #9
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        n=1,
        clip_ratio=0.2,
        entcoeff=0.01,
        target_kl=0.01,
        train_pi_iters=4,
        train_v_iters=4,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.extras = {"logprob": ()}
        self.entcoeff = entcoeff

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        ret_ph = T.Placeholder((batch_size, ), "float32")
        adv_ph = T.Placeholder((batch_size, ), "float32")
        act_ph = T.Placeholder((batch_size, num_actions), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if continuous:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
        with symjax.Scope("old_actor"):
            old_logits = actor(state_ph)
            if continuous:
                old_logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )

        if not continuous:
            pi = Categorical(logits=logits)
        else:
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

        actions = T.clip(pi.sample(), -2, 2)  # pi

        actor_params = symjax.get_variables(scope="/actor/")
        old_actor_params = symjax.get_variables(scope="/old_actor/")

        self.update_target = symjax.function(
            updates={o: a
                     for o, a in zip(old_actor_params, actor_params)})

        # PPO objectives
        # pi(a|s) / pi_old(a|s)
        pi_log_prob = pi.log_prob(act_ph)
        old_pi_log_prob = pi_log_prob.clone({
            logits: old_logits,
            logstd: old_logstd
        })

        ratio = T.exp(pi_log_prob - old_pi_log_prob)
        surr1 = ratio * adv_ph
        surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

        pi_loss = -T.minimum(surr1, surr2).mean()
        # ent_loss = pi.entropy().mean() * self.entcoeff

        with symjax.Scope("critic"):
            v = critic(state_ph)
        # critic loss
        v_loss = ((ret_ph - v)**2).mean()

        # Info (useful to watch during learning)

        # a sample estimate for KL
        approx_kl = (old_pi_log_prob - pi_log_prob).mean()
        # a sample estimate for entropy
        # approx_ent = -logprob_given_actions.mean()
        # clipped = T.logical_or(
        #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
        # )
        # clipfrac = clipped.astype("float32").mean()

        with symjax.Scope("update_pi"):
            print(len(actor_params), "actor parameters")
            nn.optimizers.Adam(
                pi_loss,
                self.lr,
                params=actor_params,
            )
        with symjax.Scope("update_v"):
            critic_params = symjax.get_variables(scope="/critic/")
            print(len(critic_params), "critic parameters")

            nn.optimizers.Adam(
                v_loss,
                self.lr,
                params=critic_params,
            )

        self.get_params = symjax.function(outputs=critic_params)

        self.learn_pi = symjax.function(
            state_ph,
            act_ph,
            adv_ph,
            outputs=[pi_loss, approx_kl],
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            ret_ph,
            outputs=v_loss,
            updates=symjax.get_updates(scope="/update_v/"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_v = v.clone({state_ph: single_state})
        single_sample = actions.clone({state_ph: single_state})

        self._act = symjax.function(single_state, outputs=single_sample)
        self._get_v = symjax.function(single_state, outputs=single_v)

        single_action = T.Placeholder((1, num_actions), "float32")
Example #10
    def __init__(
        self,
        env_fn,
        actor,
        critic,
        gamma=0.99,
        tau=0.01,
        lr=1e-3,
        batch_size=32,
        epsilon=0.1,
        epsilon_decay=1 / 1000,
        min_epsilon=0.01,
        reward=None,
    ):

        # comment out this line if you don't want to record a video of the agent
        # if save_folder is not None:
        # test_env = gym.wrappers.Monitor(test_env)

        # env_fn is assumed to be a factory returning the gym environment
        env = env_fn()

        # get size of state space and action space
        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_max = env.action_space.high[0]
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.gamma = gamma
        self.continuous = continuous
        self.observ_min = np.clip(env.observation_space.low, -20, 20)
        self.observ_max = np.clip(env.observation_space.high, -20, 20)
        self.env = env
        self.reward = reward

        # state
        state = T.Placeholder((batch_size, num_states), "float32")
        gradients = T.Placeholder((batch_size, num_actions), "float32")
        action = T.Placeholder((batch_size, num_actions), "float32")
        target = T.Placeholder((batch_size, 1), "float32")

        with symjax.Scope("actor_critic"):

            scaled_out = action_max * actor(state)
            Q = critic(state, action)

        a_loss = -T.sum(gradients * scaled_out)
        q_loss = T.mean((Q - target)**2)

        nn.optimizers.Adam(a_loss + q_loss, lr)

        self.update = symjax.function(
            state,
            action,
            target,
            gradients,
            outputs=[a_loss, q_loss],
            updates=symjax.get_updates(),
        )
        g = symjax.gradients(T.mean(Q), [action])[0]
        self.get_gradients = symjax.function(state, action, outputs=g)

        # also create the target variants
        with symjax.Scope("actor_critic_target"):
            scaled_out_target = action_max * actor(state)
            Q_target = critic(state, action)

        self.actor_predict = symjax.function(state, outputs=scaled_out)
        self.actor_predict_target = symjax.function(state,
                                                    outputs=scaled_out_target)
        self.critic_predict = symjax.function(state, action, outputs=Q)
        self.critic_predict_target = symjax.function(state,
                                                     action,
                                                     outputs=Q_target)

        t_params = symjax.get_variables(scope="/actor_critic_target/*")
        params = symjax.get_variables(scope="/actor_critic/*")
        replacement = {
            t: tau * e + (1 - tau) * t
            for t, e in zip(t_params, params)
        }
        self.update_target = symjax.function(updates=replacement)

        single_state = T.Placeholder((1, num_states), "float32")
        if not continuous:
            # discrete case: pick the greedy action from the network output
            scaled_out = scaled_out.argmax(-1)

        self.act = symjax.function(single_state,
                                   outputs=scaled_out.clone(
                                       {state: single_state})[0])
Example #11
def test_learn_bn():

    symjax.current_graph().reset()
    import tensorflow as tf
    from tensorflow.keras import layers
    import symjax.nn as nn

    np.random.seed(0)

    batch_size = 128

    W = np.random.randn(5, 5, 3, 2)
    W2 = np.random.randn(2 * 28 * 28, 10)

    inputs = layers.Input(shape=(3, 32, 32))
    out = layers.Permute((2, 3, 1))(inputs)
    out = layers.Conv2D(2,
                        5,
                        activation="linear",
                        kernel_initializer=lambda *args, **kwargs: W)(out)
    out = layers.BatchNormalization(-1)(out)
    out = layers.Activation("relu")(out)
    out = layers.Flatten()(out)
    out = layers.Dense(10,
                       activation="linear",
                       kernel_initializer=lambda *args, **kwargs: W2)(out)
    model = tf.keras.Model(inputs, out)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    label = T.Placeholder((batch_size, ), "int32")
    deterministic = T.Placeholder((), "bool")

    conv = nn.layers.Conv2D(input, 2, (5, 5), W=W.transpose((3, 2, 0, 1)))
    out = nn.layers.BatchNormalization(conv, [1], deterministic=deterministic)
    out = nn.relu(out)
    out = nn.layers.Dense(out.transpose((0, 2, 3, 1)), 10, W=W2.T)
    loss = nn.losses.sparse_softmax_crossentropy_logits(label, out).mean()
    # V2 = T.Variable(W2)
    # loss = input.sum() * V2.sum()
    # out=loss
    nn.optimizers.SGD(loss, 0.001)
    f = symjax.function(input, label, deterministic, outputs=[loss, out])
    g = symjax.function(
        input,
        label,
        deterministic,
        outputs=symjax.gradients(loss, symjax.get_variables(trainable=True)),
    )
    train = symjax.function(
        input,
        label,
        deterministic,
        outputs=loss,
        updates=symjax.get_updates(),
    )

    for epoch in range(10):
        # generate some random inputs and labels
        x = np.random.randn(batch_size, 3, 32, 32)
        y = np.random.randint(0, 10, size=batch_size)

        # test predictions during testing mode
        preds = model(x, training=False)
        nb = np.isclose(preds, f(x, y, 1)[1], atol=1e-3).mean()
        print("preds not training", nb)
        assert nb > 0.8

        # test prediction during training mode
        # now get the gradients
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(y, preds))

        nb = np.isclose(preds, f(x, y, 0)[1], atol=1e-3).mean()
        print("preds training", nb)
        assert nb > 0.8

        # test loss function during training
        losss = f(x, y, 0)[0]
        print("losses", losss, loss)
        assert np.abs(losss - loss) < 1e-3

        # test the gradients
        grads = tape.gradient(loss, model.trainable_variables)
        sj_grads = g(x, y, 0)

        for tf_g, sj_g in zip(grads, sj_grads):
            if sj_g.ndim == 4:
                sj_g = sj_g.transpose((2, 3, 1, 0))
            else:
                sj_g = sj_g.transpose()

            print(sj_g.shape, tf_g.shape)
            nb = np.isclose(
                np.reshape(sj_g, -1),
                np.reshape(tf_g, -1),
                atol=1e-3,
            ).mean()
            print("grads training", nb)
            assert nb >= 0.5

        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train(x, y, 0)
Example #12
    def variables(self, trainable=True):
        return symjax.get_variables(scope=self.scope, trainable=trainable)
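A hedged usage sketch of such a helper; the Module class, its variables, and the way self.scope is captured are assumptions for illustration:

import symjax
import symjax.tensor as T

class Module:
    def __init__(self, name):
        with symjax.Scope(name):
            # remember the scope so variables can be retrieved later
            self.scope = symjax.current_graph().scope_name
            self.w = T.Variable(1.0, name="w", trainable=True)
            self.running_mean = T.Variable(0.0, name="mean", trainable=False)

    def variables(self, trainable=True):
        return symjax.get_variables(scope=self.scope, trainable=trainable)

m = Module("module")
print(m.variables(trainable=True))   # expected to contain only w
print(m.variables(trainable=False))  # expected to contain only running_mean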
Example #13
    def __init__(
        self,
        env,
        actor,
        critic,
        lr=1e-4,
        batch_size=32,
        train_pi_iters=80,
        train_v_iters=80,
    ):

        num_states = env.observation_space.shape[0]
        continuous = type(env.action_space) == gym.spaces.box.Box

        if continuous:
            num_actions = env.action_space.shape[0]
            action_min = env.action_space.low
            action_max = env.action_space.high
        else:
            num_actions = env.action_space.n
            action_max = 1

        self.batch_size = batch_size
        self.num_states = num_states
        self.num_actions = num_actions
        self.state_dim = (batch_size, num_states)
        self.action_dim = (batch_size, num_actions)
        self.continuous = continuous
        self.lr = lr
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.extras = {}

        state_ph = T.Placeholder((batch_size, num_states), "float32")
        rew_ph = T.Placeholder((batch_size, ), "float32")

        with symjax.Scope("actor"):
            logits = actor(state_ph)
            if not continuous:
                pi = Categorical(logits=logits)
            else:
                logstd = T.Variable(
                    -0.5 * np.ones(num_actions, dtype=np.float32),
                    name="logstd",
                )
                pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

            actions = pi.sample()  # pi
            actions_log_prob = pi.log_prob(actions)  # logp

        with symjax.Scope("critic"):
            critic_value = critic(state_ph)

        # AC objectives
        diff = rew_ph - critic_value
        actor_loss = -(actions_log_prob * diff).mean()
        critic_loss = nn.losses.squared_differences(rew_ph,
                                                    critic_value).mean()

        with symjax.Scope("update_pi"):
            nn.optimizers.Adam(
                actor_loss,
                self.lr,
                params=symjax.get_variables(scope="/actor/"),
            )
        with symjax.Scope("update_v"):
            nn.optimizers.Adam(
                critic_loss,
                self.lr,
                params=symjax.get_variables(scope="/critic/"),
            )

        self.learn_pi = symjax.function(
            state_ph,
            rew_ph,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="/update_pi/"),
        )
        self.learn_v = symjax.function(
            state_ph,
            rew_ph,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="/update_v/*"),
        )

        single_state = T.Placeholder((1, num_states), "float32")
        single_action = actions.clone({state_ph: single_state})[0]
        single_v = critic_value.clone({state_ph: single_state})

        self._act = symjax.function(
            single_state,
            outputs=[single_action, single_v],
        )
Example #14
print(g.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/),
#  'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/)}

print(h.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# 'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# 'w': Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)}

print(h.variable("w"))
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)

# now suppose that we did not keep a reference to the graphs g/h; we can still
# recover a variable based on its name AND its scope

print(symjax.get_variables("/special/inversion/w"))
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)

# now if the exact scope name is not known, it is possible to use smart indexing:
# for example, if we do not remember the scope, we can still get all variables
# named 'w' across scopes

print(symjax.get_variables("*/w"))
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)

# if only part of the scope is known, all the variables of a given scope can
# be retrieved

print(symjax.get_variables("/special/*"))
# [Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/),
#  Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/),