Ejemplo n.º 1
0
    def network_fn(X, nenv=1):
        """Build a CNN feature extractor followed by a recurrent layer.

        Two placeholders are created here: a float32 mask ``M`` of length
        ``nbatch`` and a float32 state ``S`` of shape ``[nenv, 2 * nlstm]``.
        ``nlstm``, ``layer_norm`` and ``conv_kwargs`` come from the enclosing
        scope.

        Args:
            X: input tensor; its leading dimension is used as the batch size.
            nenv: number of environments the batch is divided across
                (presumably the batch size is a multiple of it -- TODO confirm).

        Returns:
            A pair ``(h, info)``. ``h`` is ``seq_to_batch`` applied to the
            recurrent outputs; ``info`` is a dict holding the two placeholders
            (``'S'``, ``'M'``), the new recurrent state (``'state'``) and a
            float64 NumPy zero array shaped like ``S`` (``'initial_state'``).
        """
        batch_size = X.shape[0]
        steps_per_env = batch_size // nenv

        features = nature_cnn(X, **conv_kwargs)

        M = tf.placeholder(tf.float32, [batch_size])  # mask (done t-1)
        S = tf.placeholder(tf.float32, [nenv, 2 * nlstm])  # states

        feature_seq = batch_to_seq(features, nenv, steps_per_env)
        mask_seq = batch_to_seq(M, nenv, steps_per_env)

        # Both cell builders are called with the same arguments; only the
        # builder and its variable scope name depend on layer_norm.
        if layer_norm:
            recurrent, scope_name = utils.lnlstm, 'lnlstm'
        else:
            recurrent, scope_name = utils.lstm, 'lstm'
        outputs, snew = recurrent(feature_seq,
                                  mask_seq,
                                  S,
                                  scope=scope_name,
                                  nh=nlstm)

        h = seq_to_batch(outputs)
        zero_state = np.zeros(S.shape.as_list(), dtype=float)

        info = {
            'S': S,
            'M': M,
            'state': snew,
            'initial_state': zero_state
        }
        return h, info
Ejemplo n.º 2
0
    def network_fn(input, mask, state):
        """Apply a single recurrent step to the flattened input.

        ``layer_norm`` and ``num_units`` are taken from the enclosing scope;
        ``layer_norm`` selects ``lnlstm`` over ``lstm``.

        Args:
            input: tensor passed through ``tf.layers.flatten`` before the cell.
            mask: tensor cast to float32 and given a trailing axis via
                ``[:, None]`` (presumably a per-row episode-done flag --
                TODO confirm against caller).
            state: recurrent state forwarded unchanged to the cell.

        Returns:
            ``(h, next_state)``: the first element of the cell's output
            sequence and the state returned by the cell.
        """
        flat = tf.layers.flatten(input)
        # tf.to_float is deprecated; tf.cast performs the same float32 cast.
        mask_f = tf.cast(mask, tf.float32)

        # The cell is fed one-element sequences, so only the builder and its
        # scope name vary with layer_norm.
        if layer_norm:
            cell, scope_name = lnlstm, 'lnlstm'
        else:
            cell, scope_name = lstm, 'lstm'
        outputs, next_state = cell([flat], [mask_f[:, None]],
                                   state,
                                   scope=scope_name,
                                   nh=num_units)
        return outputs[0], next_state
Ejemplo n.º 3
0
    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack,
                 reuse=False, nlstm=256):
        """Build a CNN + LSTM graph with policy and Q heads, plus a step function.

        Args:
            sess: session used by ``step`` to evaluate the graph.
            ob_space: observation space; ``ob_space.shape`` is unpacked into
                three values (height, width, channels).
            ac_space: action space; ``ac_space.n`` gives the number of actions.
            nenv: number of environments (first dimension of the state).
            nsteps: steps per environment; the batch size is ``nenv * nsteps``.
            nstack: multiplier applied to the channel dimension of the
                observation placeholder.
            reuse: forwarded to ``tf.variable_scope("model", ...)``.
            nlstm: LSTM size; the state placeholder has ``2 * nlstm`` columns.

        Sets ``initial_state`` (float32 zeros of shape ``(nenv, 2 * nlstm)``),
        the placeholders ``X``, ``M``, ``S``, the outputs ``pi`` and ``q``,
        and the ``step`` callable.
        """
        batch_size = nenv * nsteps
        height, width, channels = ob_space.shape
        num_actions = ac_space.n

        X = tf.placeholder(
            tf.uint8, (batch_size, height, width, channels * nstack))  # obs
        M = tf.placeholder(tf.float32, [batch_size])  # mask (done t-1)
        S = tf.placeholder(tf.float32, [nenv, nlstm * 2])  # states

        with tf.variable_scope("model", reuse=reuse):
            features = nature_cnn(X)

            # Recurrent core: batch -> sequence -> LSTM -> batch.
            feature_seq = batch_to_seq(features, nenv, nsteps)
            mask_seq = batch_to_seq(M, nenv, nsteps)
            lstm_out, snew = lstm(feature_seq, mask_seq, S, 'lstm1', nh=nlstm)
            lstm_out = seq_to_batch(lstm_out)

            # Policy head (logits + softmax) and Q-value head.
            pi_logits = fc(lstm_out, 'pi', num_actions, init_scale=0.01)
            pi = tf.nn.softmax(pi_logits)
            q = fc(lstm_out, 'q', num_actions)

        # Sampled from the logits; could be changed to use self.pi instead.
        action = sample(pi_logits)

        def step(ob, state, mask, *args, **kwargs):
            """Return (actions, mus, states); extra arguments are ignored."""
            feed = {X: ob, S: state, M: mask}
            actions, mus, new_state = sess.run([action, pi, snew], feed)
            return actions, mus, new_state

        self.X = X
        self.M = M
        self.S = S
        self.pi = pi  # softmax probabilities, not the raw logits
        self.q = q
        self.initial_state = np.zeros((nenv, nlstm * 2), dtype=np.float32)
        self.step = step
Ejemplo n.º 4
0
    def network_fn(input, mask, state):
        """Apply a CNN, a 512-unit ReLU dense layer and one recurrent step.

        ``conv_kwargs``, ``layer_norm`` and ``num_units`` are taken from the
        enclosing scope; ``layer_norm`` selects ``lnlstm`` over ``lstm``.

        Args:
            input: tensor fed to ``nature_cnn``.
            mask: tensor cast to float32 and given a trailing axis via
                ``[:, None]`` (presumably a per-row episode-done flag --
                TODO confirm against caller).
            state: recurrent state forwarded unchanged to the cell.

        Returns:
            ``(h, next_state)``: the first element of the cell's output
            sequence and the state returned by the cell.
        """
        # tf.to_float is deprecated; tf.cast performs the same float32 cast.
        mask_f = tf.cast(mask, tf.float32)
        kernel_init = ortho_init(np.sqrt(2))

        features = nature_cnn(input, **conv_kwargs)
        features = tf.layers.flatten(features)
        features = tf.layers.dense(features,
                                   units=512,
                                   activation=tf.nn.relu,
                                   kernel_initializer=kernel_init)

        # The cell is fed one-element sequences, so only the builder and its
        # scope name vary with layer_norm.
        if layer_norm:
            cell, scope_name = lnlstm, 'lnlstm'
        else:
            cell, scope_name = lstm, 'lstm'
        outputs, next_state = cell([features], [mask_f[:, None]],
                                   state,
                                   scope=scope_name,
                                   nh=num_units)
        return outputs[0], next_state