Esempio n. 1
0
    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False):
        """Build a feed-forward CNN policy/value graph and bind its run helpers.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` must unpack into three dims.
            ac_space: action space; ``n`` is the width of the policy head.
            nenv: number of environments; ``nenv * nsteps`` is the batch size.
            nsteps: number of steps per environment in one batch.
            nstack: multiplier applied to the channel dimension of the input.
            reuse: forwarded to the ``"model"`` variable scope.
        """
        batch = nenv * nsteps
        height, width, channels = ob_space.shape
        n_actions = ac_space.n
        X = tf.placeholder(tf.uint8, (batch, height, width, channels * nstack))  # obs
        self.X = X

        def linear(x):
            # Pass-through handed to ``fc`` as ``act`` for both output heads.
            return x

        with tf.variable_scope("model", reuse=reuse):
            # Cast the uint8 input to float and divide by 255.
            scaled = tf.cast(X, tf.float32) / 255.
            conv1 = conv(scaled, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
            conv2 = conv(conv1, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
            conv3 = conv(conv2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
            flat = conv_to_fc(conv3)
            hidden = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
            pi = fc(hidden, 'pi', n_actions, act=linear)
            vf = fc(hidden, 'v', 1, act=linear)

        # Column 0 of the single-output value head.
        v0 = vf[:, 0]
        # ``sample`` is defined elsewhere; presumably draws actions from ``pi``.
        a0 = sample(pi)
        self.initial_state = []  # not stateful

        def step(ob, *_args, **_kwargs):
            """Evaluate action and value tensors for ``ob``; the state is a dummy list."""
            actions, values = sess.run([a0, v0], {X: ob})
            return actions, values, []  # dummy state

        def value(ob, *_args, **_kwargs):
            """Evaluate only the value tensor for ``ob``."""
            return sess.run(v0, {X: ob})

        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value
    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False):  # pylint: disable=W0613
        """Build a CNN policy/value graph that also reports negative log-probs.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` must unpack into three dims.
            ac_space: action space; ``n`` is the width of the policy head and
                the space itself is passed to ``make_pdtype``.
            nbatch: leading dimension of the observation placeholder.
            nsteps: not used in this constructor.
            reuse: forwarded to the ``"model"`` variable scope.
        """
        height, width, channels = ob_space.shape
        n_actions = ac_space.n
        X = tf.placeholder(tf.uint8, (nbatch, height, width, channels))  # obs
        self.X = X

        with tf.variable_scope("model", reuse=reuse):
            # Cast the uint8 input to float and divide by 255.
            scaled = tf.cast(X, tf.float32) / 255.
            conv1 = conv(scaled, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
            conv2 = conv(conv1, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
            conv3 = conv(conv2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
            flat = conv_to_fc(conv3)
            hidden = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
            pi = fc(hidden, 'pi', n_actions, act=lambda x: x, init_scale=0.01)
            # Keep column 0 of the single-output value head.
            vf = fc(hidden, 'v', 1, act=lambda x: x)[:, 0]

        # Distribution object built from the policy head output.
        self.pdtype = make_pdtype(ac_space)
        self.pd = self.pdtype.pdfromflat(pi)

        sampled_action = self.pd.sample()
        sampled_neglogp = self.pd.neglogp(sampled_action)
        self.initial_state = None

        def step(ob, *_args, **_kwargs):
            """Return (action, value, initial_state, neglogp) evaluated for ``ob``."""
            action, val, nlp = sess.run([sampled_action, vf, sampled_neglogp], {X: ob})
            return action, val, self.initial_state, nlp

        def value(ob, *_args, **_kwargs):
            """Evaluate only the value tensor for ``ob``."""
            return sess.run(vf, {X: ob})

        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value
Esempio n. 3
0
    def __init__(self, sess, ob_space, ac_space, nenv,
                    nsteps, nstack, nplayers, map_size=[15,15,32], reuse=False):
        """Build a multi-player CNN policy/value graph fed with player coordinates.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` must unpack into three dims.
            ac_space: action space; ``n`` is the width of each policy head.
            nenv: number of environments; ``nenv * nsteps`` is the batch size.
            nsteps: number of steps per environment in one batch.
            nstack: multiplier applied to the channel dimension of the input.
            nplayers: size of the player axis of both placeholders.
            map_size: only ``map_size[0]`` is read here, as the divisor applied
                to the coordinates.
            reuse: forwarded to the ``"model"`` variable scope.
        """
        nbatch = nenv * nsteps
        height, width, channels = ob_space.shape
        stacked = channels * nstack
        nact = ac_space.n
        X = tf.placeholder(tf.uint8, (nbatch, nplayers, height, width, stacked))  # obs
        C = tf.placeholder(tf.int32, [nbatch, nplayers, 2])
        self.X = X
        self.C = C

        player_logits = []
        player_values = []

        with tf.variable_scope("model", reuse=reuse):
            # Fold the player axis into the batch axis so all players share one trunk.
            flat_obs = tf.reshape(tf.cast(X, tf.float32) / 255.,
                                  [nbatch * nplayers, height, width, stacked])
            coords = tf.cast(C, tf.float32) / map_size[0]
            conv1 = conv(flat_obs, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
            conv2 = conv(conv1, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
            conv3 = conv(conv2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
            flat = conv_to_fc(conv3)
            dense1 = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
            dense2 = fc(dense1, 'fc2', nh=512, init_scale=np.sqrt(2))
            dense3 = fc(dense2, 'fc3', nh=256, init_scale=np.sqrt(2))  # to give compatible network size
            # Restore the player axis, then append the scaled coordinates.
            feats = tf.reshape(dense3, [nbatch, nplayers, -1])
            feats = tf.concat([feats, coords], axis=2)

            for player in range(nplayers):
                # ``reuse`` is False for the first player only and True afterwards.
                share = player > 0
                player_logits.append(
                    fc(feats[:, player], 'pi', nact, act=tf.identity, reuse=share))
                player_values.append(
                    fc(feats[:, player], 'v', 1, act=tf.identity, reuse=share))
            pi = tf.reshape(tf.concat(player_logits, axis=1), [nbatch * nplayers, -1])
            vf = tf.reshape(tf.concat(player_values, axis=1), [nbatch * nplayers, -1])

        v0 = vf
        # ``sample`` is defined elsewhere; presumably draws actions from ``pi``.
        a0 = sample(pi)

        self.init_map = []
        self.init_state = []

        def per_env(flat_rows):
            # Cut the flat result into consecutive chunks of ``nplayers`` rows.
            return [flat_rows[start:start + nplayers]
                    for start in range(0, len(flat_rows), nplayers)]

        def step(ob, coords, *_args, **_kwargs):
            """Return chunked actions and values plus two dummy lists."""
            actions, values = sess.run([a0, v0], {X: ob, C: coords})
            return per_env(actions), per_env(values), [], []  # dummy state and recon

        def value(ob, coords, *_args, **_kwargs):
            """Return the value output chunked by ``nplayers``."""
            return per_env(sess.run(v0, {X: ob, C: coords}))

        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value
Esempio n. 4
0
def model3(X, nact, scope, reuse = False, layer_norm = False):
    """Build a conv trunk with separate conv heads for policy and value.

    Args:
        X: input tensor; cast to float32 before use.
        nact: width of the policy head and of the validity mask rows.
        scope: variable scope name for all layers.
        reuse: forwarded to the variable scope.
        layer_norm: accepted but not used in this function.

    Returns:
        A pair ``(pi_fil, value)``: the masked policy output and column 0 of
        the tanh value head.
    """
    with tf.variable_scope(scope, reuse=reuse):
        inputs = tf.cast(X, tf.float32)
        trunk1 = conv(inputs, 'c1', nf=32, rf=8, stride=1, init_scale=np.sqrt(2))
        trunk2 = conv(trunk1, 'c2', nf=64, rf=4, stride=1, init_scale=np.sqrt(2))
        trunk3 = conv(trunk2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))

        # Policy head: 1x1 conv with two filters, flattened, then a linear layer.
        policy_maps = conv(trunk3, 'c_pi', nf=2, rf=1, stride=1, init_scale=np.sqrt(2))
        policy_flat = conv_to_fc(policy_maps)
        pi = fc(policy_flat, 'pi', nact, act=lambda x: x)

        # Value head: 1x1 conv with one filter, a 256-unit layer, then tanh.
        value_maps = conv(trunk3, 'c_v1', nf=1, rf=1, stride=1, init_scale=np.sqrt(2))
        value_flat = conv_to_fc(value_maps)
        value_hidden = fc(value_flat, 'c_v2', 256, init_scale=np.sqrt(2))
        vf = fc(value_hidden, 'v', 1, act=tf.tanh)

        # Mask: take the max of X over axis 1, reshape to rows of ``nact``;
        # entries equal to 0 receive a -1e32 offset, entries equal to 1 none.
        valid = tf.reduce_max(tf.cast(X, tf.float32), axis=1)
        valid_flat = tf.reshape(valid, [-1, nact])
        penalty = (valid_flat - tf.ones(tf.shape(valid_flat))) * 1e32
        pi_fil = pi + penalty
    return pi_fil, vf[:, 0]
Esempio n. 5
0
def nature_cnn(unscaled_images, **conv_kwargs):
    """
    CNN from Nature paper.

    Casts ``unscaled_images`` to float32, divides by 255, then applies three
    ReLU-activated ``conv`` layers and one ReLU-activated 512-unit ``fc``
    layer. ``conv_kwargs`` is forwarded to every ``conv`` call (not to ``fc``).
    """
    relu = tf.nn.relu
    images = tf.cast(unscaled_images, tf.float32) / 255.

    conv1 = conv(images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs)
    out1 = relu(conv1)
    conv2 = conv(out1, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs)
    out2 = relu(conv2)
    conv3 = conv(out2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs)
    out3 = relu(conv3)

    flat = conv_to_fc(out3)
    dense = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
    return relu(dense)
Esempio n. 6
0
def model2(X, nact, scope, reuse = False, layer_norm = False):
    """Build a conv trunk with a shared 512-unit layer feeding policy and value.

    Args:
        X: input tensor; the original note says it should be
            nbatch * ncol * nrow * 2 (boolean). Cast to float32 before use.
        nact: width of the policy head and of the validity mask rows.
        scope: variable scope name for all layers.
        reuse: forwarded to the variable scope.
        layer_norm: accepted but not used in this function.

    Returns:
        A pair ``(pi_fil, value)``: the masked policy output and column 0 of
        the tanh value head.
    """
    with tf.variable_scope(scope, reuse=reuse):
        inputs = tf.cast(X, tf.float32)
        conv1 = conv(inputs, 'c1', nf=32, rf=8, stride=1, init_scale=np.sqrt(2))
        conv2 = conv(conv1, 'c2', nf=64, rf=4, stride=1, init_scale=np.sqrt(2))
        conv3 = conv(conv2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
        flat = conv_to_fc(conv3)
        hidden = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
        pi = fc(hidden, 'pi', nact, act=lambda x: x)
        vf = fc(hidden, 'v', 1, act=tf.tanh)

        # Mask: take the max of X over axis 1, reshape to rows of ``nact``
        # (the equivalent of "ind_flat_filter"); entries equal to 0 receive a
        # -1e32 offset, entries equal to 1 none.
        valid = tf.reduce_max(tf.cast(X, tf.float32), axis=1)
        valid_flat = tf.reshape(valid, [-1, nact])
        penalty = (valid_flat - tf.ones(tf.shape(valid_flat))) * 1e32
        pi_fil = pi + penalty
    return pi_fil, vf[:, 0]
Esempio n. 7
0
def model(X, nact, scope, reuse=False, layer_norm=False):
    """Build a conv policy/value network with a sign-based action mask.

    Args:
        X: input tensor; cast to float32 for the conv trunk, used uncast for
            the mask.
        nact: width of the policy head and of the mask rows.
        scope: variable scope name for all layers.
        reuse: forwarded to the variable scope.
        layer_norm: accepted but not used in this function.

    Returns:
        A pair ``(pi_fil, value)``: the masked policy output and column 0 of
        the tanh value head.
    """
    with tf.variable_scope(scope, reuse=reuse):
        # TODO: when upgraded to batch run, add layer_norm to conv
        inputs = tf.cast(X, tf.float32)
        conv1 = conv(inputs, 'c1', nf=32, rf=8, stride=1, init_scale=np.sqrt(2))
        conv2 = conv(conv1, 'c2', nf=64, rf=4, stride=1, init_scale=np.sqrt(2))
        conv3 = conv(conv2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
        flat = conv_to_fc(conv3)
        hidden = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
        pi = fc(hidden, 'pi', nact, act=lambda x: x)
        vf = fc(hidden, 'v', 1, act=tf.tanh)

        # Per the original author's notes: X holds +1 / -1 / 0 entries, where
        # axis 1 enumerates clauses. The max over that axis is 1 when a
        # variable occurs positively in any clause, the min is -1 when it
        # occurs negatively; otherwise each is 0.
        present_pos = tf.reduce_max(X, axis=1)
        present_neg = tf.reduce_min(X, axis=1)
        # Join both indicators on axis 2 and lay them out as rows of ``nact``.
        indicator = tf.reshape(tf.concat([present_pos, present_neg], axis=2), [-1, nact])
        # Absolute value: 0 marks a non-valid action, 1 marks any other.
        ind_flat_filter = tf.abs(tf.cast(indicator, tf.float32))
        # Entries equal to 0 in the filter receive a -1e32 offset.
        penalty = (ind_flat_filter - tf.ones(tf.shape(ind_flat_filter))) * 1e32
        pi_fil = pi + penalty
    return pi_fil, vf[:, 0]
    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
        """Build a CNN + LSTM policy/value graph and bind its run helpers.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` must unpack into three dims.
            ac_space: action space; ``n`` is the width of the policy head and
                the space itself is passed to ``make_pdtype``.
            nbatch: leading dimension of the observation and mask placeholders.
            nsteps: steps per environment; ``nbatch // nsteps`` environments.
            nlstm: passed to ``lstm`` as ``nh``; the state placeholder has
                ``2 * nlstm`` columns.
            reuse: forwarded to the ``"model"`` variable scope.
        """
        nenv = nbatch // nsteps
        height, width, channels = ob_space.shape
        n_actions = ac_space.n
        state_width = nlstm * 2

        X = tf.placeholder(tf.uint8, (nbatch, height, width, channels))  # obs
        M = tf.placeholder(tf.float32, [nbatch])  # mask (done t-1)
        S = tf.placeholder(tf.float32, [nenv, state_width])  # states
        self.X = X
        self.M = M
        self.S = S

        with tf.variable_scope("model", reuse=reuse):
            scaled = tf.cast(X, tf.float32) / 255.
            conv1 = conv(scaled, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
            conv2 = conv(conv1, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
            conv3 = conv(conv2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
            flat = conv_to_fc(conv3)
            dense = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))
            # Convert features and masks with batch_to_seq before the
            # recurrent layer, then back with seq_to_batch afterwards.
            feature_seq = batch_to_seq(dense, nenv, nsteps)
            mask_seq = batch_to_seq(M, nenv, nsteps)
            recurrent, snew = lstm(feature_seq, mask_seq, S, 'lstm1', nh=nlstm)
            recurrent = seq_to_batch(recurrent)
            pi = fc(recurrent, 'pi', n_actions, act=lambda x: x)
            vf = fc(recurrent, 'v', 1, act=lambda x: x)

        # Distribution object built from the policy head output.
        self.pdtype = make_pdtype(ac_space)
        self.pd = self.pdtype.pdfromflat(pi)

        v0 = vf[:, 0]
        a0 = self.pd.sample()
        neglogp0 = self.pd.neglogp(a0)
        self.initial_state = np.zeros((nenv, state_width), dtype=np.float32)

        def step(ob, state, mask):
            """Return ``[action, value, new_state, neglogp]`` from one session run."""
            feed = {X: ob, S: state, M: mask}
            return sess.run([a0, v0, snew, neglogp0], feed)

        def value(ob, state, mask):
            """Evaluate only the value tensor."""
            feed = {X: ob, S: state, M: mask}
            return sess.run(v0, feed)

        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value
Esempio n. 9
0
    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=True):  #pylint: disable=W0613
        """Build a 1-row-kernel conv policy/value graph with a softmax output.

        The first column along axis 2 of the observation is split off, the
        remaining columns go through two convolutions, the split-off column is
        concatenated back on the channel axis, and a final 1x1 convolution
        produces one score per row. A constant-one "cash bias" row is
        prepended before the softmax.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` is appended to ``nbatch`` to
                form the placeholder shape and ``shape[1]`` sets the window.
            ac_space: action space; ``shape[0]`` is the width of ``logstd`` and
                the space itself is passed to ``make_pdtype``.
            nbatch: leading dimension of the observation placeholder.
            nsteps: not used in this constructor.
            reuse: forwarded to the ``"model"`` variable scope.
                NOTE(review): defaults to True, unlike the sibling policies;
                the original comment says "reuse when testing" — confirm the
                first construction passes ``reuse=False``.
        """
        ob_shape = (nbatch, ) + ob_space.shape
        actdim = ac_space.shape[0]

        # Width left on axis 2 of X once its first column is sliced off below.
        window_length = ob_space.shape[1] - 1

        X = tf.placeholder(tf.float32, ob_shape, name='Ob')  #obs

        with tf.variable_scope("model", reuse=reuse):
            # w0: first column and first channel; x: every remaining column.
            w0 = tf.slice(X, [0, 0, 0, 0], [-1, -1, 1, 1])
            x = tf.slice(X, [0, 0, 1, 0], [-1, -1, -1, -1])

            x = conv(tf.cast(x, tf.float32),
                     'c1',
                     fh=1,
                     fw=3,
                     nf=3,
                     stride=1,
                     init_scale=np.sqrt(2))

            # Filter width and stride are both window_length - 2.
            x = conv(x,
                     'c2',
                     fh=1,
                     fw=window_length - 2,
                     nf=20,
                     stride=window_length - 2,
                     init_scale=np.sqrt(2))

            # Re-attach the split-off column as an extra channel.
            x = tf.concat([x, w0], 3)

            x = conv(x,
                     'c3',
                     fh=1,
                     fw=1,
                     nf=1,
                     stride=1,
                     init_scale=np.sqrt(2))

            # Constant-one entry prepended along axis 1 ahead of the scores.
            cash_bias = tf.ones([x.shape[0], 1, 1, 1], tf.float32)
            c = tf.concat([cash_bias, x], 1)

            # The value head reads the scores without the cash-bias row.
            v = conv_to_fc(x)
            vf = fc(v, 'v', 1)[:, 0]

            f = tf.contrib.layers.flatten(c)

            pi = tf.nn.softmax(f)

            logstd = tf.get_variable(
                name="logstd",
                shape=[1, actdim],
                initializer=tf.truncated_normal_initializer())

        # ``pi * 0.0 + logstd`` broadcasts logstd to the shape of pi.
        pdparam = tf.concat([pi, pi * 0.0 + logstd], axis=1)

        self.pdtype = make_pdtype(ac_space)
        self.pd = self.pdtype.pdfromflat(pdparam)

        a0 = self.pd.sample()
        a0 = tf.nn.softmax(a0)

        # NOTE(review): neglogp is evaluated on the softmaxed action, not on
        # the raw sample — confirm this is intended.
        neglogp0 = self.pd.neglogp(a0)
        self.initial_state = None

        def step(ob, *_args, **_kwargs):
            """Return (action, value, initial_state, neglogp, logstd[0], pi) for ``ob``."""
            a, v, neglogp, lst, p = sess.run([a0, vf, neglogp0, logstd, pi],
                                             {X: ob})
            return a, v, self.initial_state, neglogp, lst[0], p

        def value(ob, *_args, **_kwargs):
            """Evaluate only the value tensor for ``ob``."""
            return sess.run(vf, {X: ob})

        self.X = X
        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value
Esempio n. 10
0
    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack,
            nplayers, nlstm=256, reuse=False):
        """Build a multi-player CNN policy whose players share a recurrent memory.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` must unpack into three dims.
            ac_space: action space; ``n`` is the width of each policy head.
            nenv: number of environments; ``nenv * nsteps`` is the batch size.
            nsteps: number of steps per environment in one batch.
            nstack: multiplier applied to the channel dimension of the input.
            nplayers: size of the player axis of the observation and mask.
            nlstm: passed to ``lnmem`` as ``nh``; the state placeholder has
                ``2 * nlstm`` columns.
            reuse: forwarded to the ``"model"`` variable scope.
        """
        nbatch = nenv * nsteps
        height, width, channels = ob_space.shape
        stacked = channels * nstack
        nact = ac_space.n
        state_width = nlstm * 2

        X = tf.placeholder(tf.uint8, (nbatch, nplayers, height, width, stacked), name='X')
        M = tf.placeholder(tf.float32, [nbatch, nplayers], name='M')
        S = tf.placeholder(tf.float32, [nbatch, state_width], name='S')
        self.X = X
        self.M = M
        self.S = S

        player_logits = []
        player_values = []

        with tf.variable_scope("model", reuse=reuse):
            # Fold the player axis into the batch axis so all players share one trunk.
            flat_obs = tf.reshape(tf.cast(X, tf.float32) / 255.,
                                  [nbatch * nplayers, height, width, stacked])
            conv1 = conv(flat_obs, 'conv1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
            conv2 = conv(conv1, 'conv2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
            conv3 = conv(conv2, 'conv3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
            flat = conv_to_fc(conv3)
            dense = fc(flat, 'fc1', nh=512, init_scale=np.sqrt(2))

            # Shared memory: instead of a time sequence, each rnn cell here
            # is responsible for "one player".
            feature_seq = batch_to_seq(dense, nbatch, nplayers)
            mask_seq = batch_to_seq(M, nbatch, nplayers)

            memory, snew = lnmem(feature_seq, mask_seq, S, 'lstm1', nh=nlstm)
            memory = tf.reshape(memory, [nbatch, state_width])
            memory = fc(memory, 'fcmem', nh=256, init_scale=np.sqrt(2))
            per_player = tf.reshape(dense, [nbatch, nplayers, -1])

            # Compute pi and value for each agent from memory plus its own features.
            for player in range(nplayers):
                # ``reuse`` is False for the first player only and True afterwards.
                share = player > 0
                joint = tf.concat([memory, per_player[:, player]], axis=1)
                head = fc(joint, 'fc-pi', nh=512, init_scale=np.sqrt(2), reuse=share)
                player_logits.append(fc(head, 'pi', nact, act=tf.identity, reuse=share))
                player_values.append(fc(head, 'v', 1, act=tf.identity, reuse=share))
            pi = tf.reshape(tf.concat(player_logits, axis=1), [nbatch * nplayers, -1])
            vf = tf.reshape(tf.concat(player_values, axis=1), [nbatch * nplayers, -1])

        v0 = vf
        # ``sample`` is defined elsewhere; presumably draws actions from ``pi``.
        a0 = sample(pi)

        self.init_map = []
        self.init_state = np.zeros((nbatch, state_width), dtype=np.float32)

        def per_env(flat_rows):
            # Cut the flat result into consecutive chunks of ``nplayers`` rows.
            return [flat_rows[start:start + nplayers]
                    for start in range(0, len(flat_rows), nplayers)]

        def step(ob, state, mask):
            """Return chunked actions and values, the new state and a dummy list."""
            actions, values, new_state = sess.run(
                [a0, v0, snew], {X: ob, S: state, M: mask})
            return per_env(actions), per_env(values), new_state, []  # dummy recon

        def value(ob, state, mask):
            """Return the value output chunked by ``nplayers``."""
            return per_env(sess.run(v0, {X: ob, S: state, M: mask}))

        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value
Esempio n. 11
0
    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack,
            nplayers, map_size=(15, 15, 32), reuse=False):
        """Build a multi-player CNN policy that reads and writes a shared map.

        Args:
            sess: session used by ``step`` and ``value`` to evaluate the graph.
            ob_space: observation space; ``shape`` must unpack into three dims.
            ac_space: action space; ``n`` is the width of each policy head.
            nenv: number of environments; ``nenv * nsteps`` is the batch size.
            nsteps: number of steps per environment in one batch.
            nstack: multiplier applied to the channel dimension of the input.
            nplayers: size of the player axis of the observation and
                coordinate placeholders.
            map_size: per-sample shape of the map placeholder; any sequence
                (list or tuple) is accepted. ``map_size[1]`` and
                ``map_size[-1]`` are passed to ``nmap`` as ``n`` and ``feat``.
            reuse: forwarded to the ``"model"`` variable scope.
        """
        # Work on a private list: the default is immutable, and list
        # concatenation below would fail if a caller passed a tuple.
        map_size = list(map_size)

        nbatch = nenv * nsteps
        nh, nw, nc = ob_space.shape
        ob_shape = (nbatch, nplayers, nh, nw, nc*nstack)
        nact = ac_space.n

        X = tf.placeholder(tf.uint8, ob_shape, name='X')
        MAP = tf.placeholder(tf.float32, [nbatch,] + map_size, name='MEM')
        C = tf.placeholder(tf.int32, [nbatch, nplayers, 2])

        pis = []
        vfs = []

        # Non-trainable variable holding the map; its name depends on nbatch.
        with tf.variable_scope("mem-var"):
            m = tf.get_variable("mem-var-%d" % nbatch, shape=MAP.get_shape(), trainable=False)

        with tf.variable_scope("model", reuse=reuse):
            # Tuck observations from all players together at once.
            x = tf.reshape(tf.cast(X, tf.float32)/255., [nbatch*nplayers, nh, nw, nc*nstack])
            # Load the fed map, scaled by 0.9, into the variable.
            m = tf.assign(m, MAP * 0.9)

            h1 = conv( x, 'conv1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), act=swish)
            h2 = conv(h1, 'conv2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), act=swish)
            h3 = conv(h2, 'conv3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), act=swish)
            h3 = conv_to_fc(h3)
            h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2), act=swish)

            # Shared memory: instead of a time sequence, each rnn cell here
            # is responsible for "one player".

            # mem_size: half for key, half for value (original author's note).
            xs = batch_to_seq(h4, nenv*nsteps, nplayers)
            crds = batch_to_seq(C, nenv*nsteps, nplayers)
            context, rs, ws, map_new = nmap(xs, m, crds, 'mem', nplayers, n=map_size[1], feat=map_size[-1])
            # NOTE(review): this reshaped h4 is not consumed by the heads
            # below; kept so the constructed graph is unchanged.
            h4 = tf.reshape(h4, [nbatch, nplayers, -1])

            context = tf.reshape(context, [nbatch, nplayers, -1])
            ws = tf.reshape(ws, [nbatch, nplayers, -1])
            # From here on crds is the float coordinate tensor, not the sequence.
            crds = tf.cast(C, tf.float32)

            # Compute pi and value for each agent.

            # Head variables: ``reuse`` is False on the first pass only.
            _reuse = False
            for i in range(nplayers):
                h5 = fc(tf.concat([context[:, i], crds[:, i], rs, ws[:, i]], axis=1), 'fc-pi', nh=512, init_scale=np.sqrt(2), reuse=_reuse, act=swish)
                pi = fc(h5, 'pi', nact, act=tf.identity, reuse=_reuse)
                vf = fc(h5, 'v', 1, act=tf.identity, reuse=_reuse)
                pis.append(pi)
                vfs.append(vf)
                _reuse = True
            pi = tf.reshape(tf.concat(pis, axis=1), [nbatch*nplayers, -1])
            vf = tf.reshape(tf.concat(vfs, axis=1), [nbatch*nplayers, -1])

        v0 = vf
        # ``sample`` is defined elsewhere; presumably draws actions from ``pi``.
        a0 = sample(pi)
        # Argmax over axis 1 of ``pi``, used when ``training`` is False.
        a_max = tf.argmax(pi, 1)

        self.init_state = []
        self.init_map = np.zeros([nbatch,]+map_size, dtype=np.float32)

        def step(ob, maps, coords, training=True):
            """Return chunked actions and values, a dummy state and the new map.

            Uses the sampled action op when ``training`` is true and the
            argmax op otherwise.
            """
            pi_op = a0 if training else a_max
            a, v, m = sess.run([pi_op, v0, map_new], {X:ob, MAP:maps, C:coords})
            a = [a[i:i+nplayers] for i in range(0, len(a), nplayers)]
            v = [v[i:i+nplayers] for i in range(0, len(v), nplayers)]
            return a, v, [], m # dummy state

        def value(ob, maps, coords):
            """Return the value output chunked by ``nplayers``."""
            v = sess.run(v0, {X:ob, MAP:maps, C:coords})
            v = [v[i:i+nplayers] for i in range(0, len(v), nplayers)]
            return v

        self.X = X
        self.C = C
        self.MAP = MAP # to meet the a2c.py protocol.
        self.pi = pi
        self.vf = vf
        self.step = step
        self.value = value