def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 lr_rate=0.01,
                 total_steps=50000,
                 scope="discriminator",
                 kfac_clip=0.001,
                 max_grad_norm=0.5):
        """Build the discriminator graph, its losses and its optimizers.

        Args:
            sess: session object, stored on the instance as ``self.sess``.
            ob_spaces: per-agent observation spaces; only ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; only ``n`` is read.
            nstack: multiplier applied to every observation/action width.
            index: position of this discriminator's agent in the space lists.
            disc_type: 'decentralized', 'centralized' or 'single'; it must
                also be contained in the module-level ``disc_types``.
            hidden_size: stored as ``self.hidden_size``.
            lr_rate: initial value handed to the linear ``Scheduler``.
            total_steps: ``nvalues`` of the learning-rate ``Scheduler``.
            scope: variable-scope name used to look up the variables.
            kfac_clip: forwarded to ``KfacOptimizer`` as ``clip_kl``.
            max_grad_norm: forwarded to ``KfacOptimizer``.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate, nvalues=total_steps, schedule='linear')
        self.disc_type = disc_type
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if disc_type not in disc_types:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.scope = scope
        self.index = index
        self.sess = sess
        ob_space = ob_spaces[index]
        ac_space = ac_spaces[index]
        self.ob_shape = ob_space.shape[0] * nstack
        nact = ac_space.n
        self.ac_shape = nact * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # 'decentralized' sees only this agent's observation/action widths;
        # the other two modes see the widths of all agents.
        if disc_type == 'decentralized':
            input_shape = self.ob_shape + self.ac_shape
        elif disc_type in ('centralized', 'single'):
            input_shape = self.all_ob_shape + self.all_ac_shape
        else:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))

        # ``g`` and ``e`` are the two input batches being told apart
        # (presumably generated vs. expert rows -- confirm against callers).
        self.g = tf.placeholder(tf.float32, (None, input_shape))
        self.e = tf.placeholder(tf.float32, (None, input_shape))
        self.lr_rate = tf.placeholder(tf.float32, ())
        self.adv = tf.placeholder(tf.float32, ())

        num_outputs = len(ob_spaces) if disc_type == 'centralized' else 1

        # First call creates the variables; every later call reuses them.
        logits = self.build_graph(tf.concat([self.g, self.e], axis=0),
                                  num_outputs,
                                  reuse=False)
        # +1 for rows coming from ``g``, -1 for rows coming from ``e``.
        labels = tf.concat([
            tf.ones([tf.shape(self.g)[0], 1]),
            -tf.ones([tf.shape(self.e)[0], 1])
        ],
                           axis=0)

        g_logits = self.build_graph(self.g, num_outputs, reuse=True)
        e_logits = self.build_graph(self.e, num_outputs, reuse=True)

        self.g_loss = tf.reduce_mean(g_logits)
        self.e_loss = tf.reduce_mean(-e_logits)

        # Signed logits, left un-reduced (one entry per row and output).
        self.total_loss = logits * labels

        # Penalty on the gradient of the network output at a random
        # interpolation between the two batches (one scalar epsilon per run).
        epsilon = tf.random_uniform([], 0.0, 1.0)
        ge = self.g * epsilon + self.e * (1 - epsilon)
        gel = self.build_graph(ge, num_outputs, reuse=True)
        ddd = tf.gradients(gel, [ge])
        # NOTE(review): ``tf.gradients`` returns a one-element list here, so
        # ``axis=1`` appears to index the batch dimension of the stacked
        # (1, batch, features) tensor rather than the features -- confirm
        # that this is the intended norm.
        ddd = tf.norm(ddd, axis=1)
        self.ddd = tf.reduce_mean(tf.square(ddd - 1.)) * 5

        # Logits perturbed with unit Gaussian noise; the squared distance to
        # that (gradient-stopped) sample is the loss handed to KFAC's
        # statistics update below.
        sample_net = logits + tf.random_normal(tf.shape(logits))
        fisher_loss = -tf.reduce_mean(
            tf.pow(logits - tf.stop_gradient(sample_net), 2))

        self.reward_op = tf.sigmoid(g_logits)

        self.var_list = self.get_trainable_variables()
        params = find_trainable_variables(self.scope)
        grads = tf.gradients(self.total_loss, params)

        with tf.variable_scope(self.scope + '/d_optim'):
            # ``async`` is a reserved word from Python 3.7 on, so writing it
            # as a literal keyword argument is a SyntaxError there; passing
            # it through a dict keeps the call identical on older versions.
            # NOTE(review): confirm the installed kfac module still names
            # this parameter ``async``.
            d_optim = kfac.KfacOptimizer(learning_rate=self.lr_rate,
                                         clip_kl=kfac_clip,
                                         momentum=0.9,
                                         kfac_update=1,
                                         epsilon=0.01,
                                         stats_decay=0.99,
                                         cold_iter=10,
                                         max_grad_norm=max_grad_norm,
                                         **{'async': 0})
            # The returned op is not stored; the call is kept as in the
            # original because it is made on the optimizer before
            # ``apply_gradients``.
            update_stats_op = d_optim.compute_and_apply_stats(fisher_loss,
                                                              var_list=params)
            train_op, q_runner = d_optim.apply_gradients(
                list(zip(grads, params)))
            self.q_runner = q_runner

        # No ``var_list`` is passed, so the optimizer's default variable
        # collection is used -- NOTE(review): confirm this should not be
        # restricted to the discriminator scope.
        self.g_optim = tf.train.AdamOptimizer(learning_rate=0.0005).minimize(
            self.ddd)
        self.d_optim = train_op
        self.saver = tf.train.Saver(self.get_variables())

        self.params_flat = self.get_trainable_variables()
    def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 state_only,
                 discount,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 lr_rate=0.01,
                 total_steps=50000,
                 scope="discriminator",
                 kfac_clip=0.001,
                 max_grad_norm=0.5,
                 l2_loss_ratio=0.01):
        """Build the reward/value networks, the loss and the Adam train op.

        Args:
            sess: session object, stored on the instance as ``self.sess``.
            ob_spaces: per-agent observation spaces; only ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; ``n`` is used when present,
                otherwise ``shape[0]``.
            state_only: when true the reward network sees observations only,
                otherwise observations concatenated with actions.
            discount: factor applied to the next-observation value.
            nstack: multiplier applied to every observation/action width.
            index: position of this discriminator's agent in the space lists.
            disc_type: 'decentralized' or 'decentralized-all'; it must also
                be contained in the module-level ``disc_types``.
            hidden_size: stored as ``self.hidden_size``.
            lr_rate: initial value handed to the linear ``Scheduler``.
            total_steps: ``nvalues`` of the learning-rate ``Scheduler``.
            scope: variable scope that owns the networks.
            kfac_clip: accepted but not used by this constructor.
            max_grad_norm: accepted but not used by this constructor.
            l2_loss_ratio: weight of the L2 penalty over the scope's
                trainable variables.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate, nvalues=total_steps, schedule='linear')
        self.disc_type = disc_type
        self.l2_loss_ratio = l2_loss_ratio
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if disc_type not in disc_types:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.state_only = state_only
        self.gamma = discount
        self.scope = scope
        self.index = index
        self.sess = sess
        ob_space = ob_spaces[index]
        ac_space = ac_spaces[index]
        self.ob_shape = ob_space.shape[0] * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        # Action spaces without an ``n`` attribute fall back to ``shape[0]``.
        # Only the missing attribute is caught, so unrelated errors
        # (including KeyboardInterrupt) are no longer swallowed.
        try:
            nact = ac_space.n
        except AttributeError:
            nact = ac_space.shape[0]
        self.ac_shape = nact * nstack
        try:
            self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        except AttributeError:
            self.all_ac_shape = sum(ac.shape[0] for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # 'decentralized' uses this agent's widths, 'decentralized-all' the
        # summed widths of every agent; the placeholders are otherwise equal.
        if disc_type == 'decentralized':
            ob_dim, ac_dim = self.ob_shape, self.ac_shape
        elif disc_type == 'decentralized-all':
            ob_dim, ac_dim = self.all_ob_shape, self.all_ac_shape
        else:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.obs = tf.placeholder(tf.float32, (None, ob_dim))
        self.nobs = tf.placeholder(tf.float32, (None, ob_dim))
        self.act = tf.placeholder(tf.float32, (None, ac_dim))
        self.labels = tf.placeholder(tf.float32, (None, 1))
        self.lprobs = tf.placeholder(tf.float32, (None, 1))

        self.lr_rate = tf.placeholder(tf.float32, ())

        with tf.variable_scope(self.scope):
            rew_input = self.obs
            if not self.state_only:
                rew_input = tf.concat([self.obs, self.act], axis=1)

            # NOTE(review): ``relu_net`` is called without ``hidden_size``,
            # so ``self.hidden_size`` is not forwarded here -- confirm.
            with tf.variable_scope('reward'):
                self.reward = self.relu_net(rew_input, dout=1)

            # One value network, applied to the next and the current
            # observations with shared variables.
            with tf.variable_scope('vfn'):
                self.value_fn_n = self.relu_net(self.nobs, dout=1)
            with tf.variable_scope('vfn', reuse=True):
                self.value_fn = self.relu_net(self.obs, dout=1)

            log_q_tau = self.lprobs
            log_p_tau = self.reward + self.gamma * self.value_fn_n - self.value_fn
            # log(exp(log_p) + exp(log_q)), computed element-wise.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            # exp(log_p) / (exp(log_p) + exp(log_q))
            self.discrim_output = tf.exp(log_p_tau - log_pq)

        # Binary cross-entropy written in log space: rows labelled 1 are
        # scored with log_p - log_pq, rows labelled 0 with log_q - log_pq.
        self.total_loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                          (1 - self.labels) *
                                          (log_q_tau - log_pq))
        self.var_list = self.get_trainable_variables()
        params = find_trainable_variables(self.scope)
        self.l2_loss = tf.add_n([tf.nn.l2_loss(v)
                                 for v in params]) * self.l2_loss_ratio
        self.total_loss += self.l2_loss

        grads = tf.gradients(self.total_loss, params)
        with tf.variable_scope(self.scope + '/d_optim'):
            d_optim = tf.train.AdamOptimizer(learning_rate=self.lr_rate)
            train_op = d_optim.apply_gradients(list(zip(grads, params)))
        self.d_optim = train_op
        self.saver = tf.train.Saver(self.get_variables())

        self.params_flat = self.get_trainable_variables()
class Discriminator(object):
    """Feed-forward discriminator over concatenated observation/action rows.

    Two batches, ``g`` and ``e`` (presumably generated vs. expert rows --
    confirm against callers), are scored by one shared network. The
    discriminator step is a KFAC update on the signed logits; a separate Adam
    op minimizes a gradient-penalty term.
    """

    def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 lr_rate=0.01,
                 total_steps=50000,
                 scope="discriminator",
                 kfac_clip=0.001,
                 max_grad_norm=0.5):
        """Build the discriminator graph, its losses and its optimizers.

        Args:
            sess: session used by ``get_reward``, ``train``, ``save``, ...
            ob_spaces: per-agent observation spaces; only ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; only ``n`` is read.
            nstack: multiplier applied to every observation/action width.
            index: position of this discriminator's agent in the space lists.
            disc_type: 'decentralized', 'centralized' or 'single'; it must
                also be contained in the module-level ``disc_types``.
            hidden_size: width of the two hidden layers in ``build_graph``.
            lr_rate: initial value handed to the linear ``Scheduler``.
            total_steps: ``nvalues`` of the learning-rate ``Scheduler``.
            scope: variable scope that owns the network variables.
            kfac_clip: forwarded to ``KfacOptimizer`` as ``clip_kl``.
            max_grad_norm: forwarded to ``KfacOptimizer``.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate, nvalues=total_steps, schedule='linear')
        self.disc_type = disc_type
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if disc_type not in disc_types:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.scope = scope
        self.index = index
        self.sess = sess
        ob_space = ob_spaces[index]
        ac_space = ac_spaces[index]
        self.ob_shape = ob_space.shape[0] * nstack
        nact = ac_space.n
        self.ac_shape = nact * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # 'decentralized' sees only this agent's observation/action widths;
        # the other two modes see the widths of all agents.
        if disc_type == 'decentralized':
            input_shape = self.ob_shape + self.ac_shape
        elif disc_type in ('centralized', 'single'):
            input_shape = self.all_ob_shape + self.all_ac_shape
        else:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))

        self.g = tf.placeholder(tf.float32, (None, input_shape))
        self.e = tf.placeholder(tf.float32, (None, input_shape))
        self.lr_rate = tf.placeholder(tf.float32, ())
        self.adv = tf.placeholder(tf.float32, ())

        num_outputs = len(ob_spaces) if disc_type == 'centralized' else 1

        # First call creates the variables; every later call reuses them.
        logits = self.build_graph(tf.concat([self.g, self.e], axis=0),
                                  num_outputs,
                                  reuse=False)
        # +1 for rows coming from ``g``, -1 for rows coming from ``e``.
        labels = tf.concat([
            tf.ones([tf.shape(self.g)[0], 1]),
            -tf.ones([tf.shape(self.e)[0], 1])
        ],
                           axis=0)

        g_logits = self.build_graph(self.g, num_outputs, reuse=True)
        e_logits = self.build_graph(self.e, num_outputs, reuse=True)

        self.g_loss = tf.reduce_mean(g_logits)
        self.e_loss = tf.reduce_mean(-e_logits)

        # Signed logits, left un-reduced (one entry per row and output).
        self.total_loss = logits * labels

        # Penalty on the gradient of the network output at a random
        # interpolation between the two batches (one scalar epsilon per run).
        epsilon = tf.random_uniform([], 0.0, 1.0)
        ge = self.g * epsilon + self.e * (1 - epsilon)
        gel = self.build_graph(ge, num_outputs, reuse=True)
        ddd = tf.gradients(gel, [ge])
        # NOTE(review): ``tf.gradients`` returns a one-element list here, so
        # ``axis=1`` appears to index the batch dimension of the stacked
        # (1, batch, features) tensor rather than the features -- confirm
        # that this is the intended norm.
        ddd = tf.norm(ddd, axis=1)
        self.ddd = tf.reduce_mean(tf.square(ddd - 1.)) * 5

        # Logits perturbed with unit Gaussian noise; the squared distance to
        # that (gradient-stopped) sample is the loss handed to KFAC's
        # statistics update below.
        sample_net = logits + tf.random_normal(tf.shape(logits))
        fisher_loss = -tf.reduce_mean(
            tf.pow(logits - tf.stop_gradient(sample_net), 2))

        self.reward_op = tf.sigmoid(g_logits)

        self.var_list = self.get_trainable_variables()
        params = find_trainable_variables(self.scope)
        grads = tf.gradients(self.total_loss, params)

        with tf.variable_scope(self.scope + '/d_optim'):
            # ``async`` is a reserved word from Python 3.7 on, so writing it
            # as a literal keyword argument is a SyntaxError there; passing
            # it through a dict keeps the call identical on older versions.
            # NOTE(review): confirm the installed kfac module still names
            # this parameter ``async``.
            d_optim = kfac.KfacOptimizer(learning_rate=self.lr_rate,
                                         clip_kl=kfac_clip,
                                         momentum=0.9,
                                         kfac_update=1,
                                         epsilon=0.01,
                                         stats_decay=0.99,
                                         cold_iter=10,
                                         max_grad_norm=max_grad_norm,
                                         **{'async': 0})
            # The returned op is not stored; the call is kept as in the
            # original because it is made on the optimizer before
            # ``apply_gradients``.
            update_stats_op = d_optim.compute_and_apply_stats(fisher_loss,
                                                              var_list=params)
            train_op, q_runner = d_optim.apply_gradients(
                list(zip(grads, params)))
            self.q_runner = q_runner

        # No ``var_list`` is passed, so the optimizer's default variable
        # collection is used -- NOTE(review): confirm this should not be
        # restricted to the discriminator scope.
        self.g_optim = tf.train.AdamOptimizer(learning_rate=0.0005).minimize(
            self.ddd)
        self.d_optim = train_op
        self.saver = tf.train.Saver(self.get_variables())

        self.params_flat = self.get_trainable_variables()

    def build_graph(self, x, num_outputs=1, reuse=False):
        """Apply the network (two hidden ``fc`` layers plus a linear output).

        Args:
            x: input tensor.
            num_outputs: width of the final layer.
            reuse: when true, reuse the variables already created in
                ``self.scope`` instead of creating new ones.

        Returns:
            The output of the final, identity-activated ``fc`` layer.
        """
        with tf.variable_scope(self.scope):
            if reuse:
                tf.get_variable_scope().reuse_variables()
            p_h1 = fc(x, 'fc1', nh=self.hidden_size)
            p_h2 = fc(p_h1, 'fc2', nh=self.hidden_size)
            logits = fc(p_h2, 'out', nh=num_outputs, act=lambda x: x)
        return logits

    def get_variables(self):
        """Return the global variables collected under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)

    def get_trainable_variables(self):
        """Return the trainable variables collected under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_reward(self, obs, acs):
        """Evaluate ``reward_op`` (sigmoid of the logits) for the given rows.

        One-dimensional ``obs``/``acs`` are promoted to a batch of one before
        being concatenated along axis 1 and fed to the ``g`` placeholder.
        """
        if len(obs.shape) == 1:
            obs = np.expand_dims(obs, 0)
        if len(acs.shape) == 1:
            acs = np.expand_dims(acs, 0)
        feed_dict = {self.g: np.concatenate([obs, acs], axis=1)}
        return self.sess.run(self.reward_op, feed_dict)

    def train(self, g_obs, g_acs, e_obs, e_acs):
        """Run one discriminator step followed by five penalty steps.

        Args:
            g_obs, g_acs: observation/action rows fed to the ``g`` placeholder.
            e_obs, e_acs: observation/action rows fed to the ``e`` placeholder.

        Returns:
            ``(g_loss, e_loss, None, None)`` evaluated after the updates.
        """
        feed_dict = {
            self.g: np.concatenate([g_obs, g_acs], axis=1),
            self.e: np.concatenate([e_obs, e_acs], axis=1),
            self.lr_rate: self.lr.value()
        }
        # The fetched loss value is not needed; only the update matters.
        self.sess.run([self.total_loss, self.d_optim], feed_dict)
        for _ in range(5):
            self.sess.run(self.g_optim, feed_dict)
        g_loss, e_loss = self.sess.run([self.g_loss, self.e_loss], feed_dict)
        return g_loss, e_loss, None, None

    def restore(self, path):
        """Restore the scope's variables from a ``tf.train.Saver`` file."""
        print('restoring from:' + path)
        self.saver.restore(self.sess, path)

    def save(self, save_path):
        """Dump the current trainable-variable values with ``joblib``."""
        ps = self.sess.run(self.params_flat)
        joblib.dump(ps, save_path)

    def load(self, load_path):
        """Assign values written by ``save`` back to the trainable variables.

        Values are matched to variables by position in ``params_flat``.
        """
        loaded_params = joblib.load(load_path)
        restores = [
            p.assign(loaded_p)
            for p, loaded_p in zip(self.params_flat, loaded_params)
        ]
        self.sess.run(restores)
class Discriminator(object):
    """Discriminator built from a reward network and a value network.

    The discriminator output is ``exp(f) / (exp(f) + exp(lprobs))`` with
    ``f = reward + gamma * value(next_obs) - value(obs)``, where ``lprobs``
    is fed by the caller.
    """

    def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 state_only,
                 discount,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 lr_rate=0.01,
                 total_steps=50000,
                 scope="discriminator",
                 kfac_clip=0.001,
                 max_grad_norm=0.5,
                 l2_loss_ratio=0.01):
        """Build the reward/value networks, the loss and the Adam train op.

        Args:
            sess: session used by ``get_reward``, ``train``, ``save``, ...
            ob_spaces: per-agent observation spaces; only ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; ``n`` is used when present,
                otherwise ``shape[0]``.
            state_only: when true the reward network sees observations only,
                otherwise observations concatenated with actions.
            discount: factor applied to the next-observation value.
            nstack: multiplier applied to every observation/action width.
            index: position of this discriminator's agent in the space lists.
            disc_type: 'decentralized' or 'decentralized-all'; it must also
                be contained in the module-level ``disc_types``.
            hidden_size: stored as ``self.hidden_size`` (see the review note
                at the ``relu_net`` calls below).
            lr_rate: initial value handed to the linear ``Scheduler``.
            total_steps: ``nvalues`` of the learning-rate ``Scheduler``.
            scope: variable scope that owns the networks.
            kfac_clip: accepted but not used by this constructor.
            max_grad_norm: accepted but not used by this constructor.
            l2_loss_ratio: weight of the L2 penalty over the scope's
                trainable variables.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate, nvalues=total_steps, schedule='linear')
        self.disc_type = disc_type
        self.l2_loss_ratio = l2_loss_ratio
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if disc_type not in disc_types:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.state_only = state_only
        self.gamma = discount
        self.scope = scope
        self.index = index
        self.sess = sess
        ob_space = ob_spaces[index]
        ac_space = ac_spaces[index]
        self.ob_shape = ob_space.shape[0] * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        # Action spaces without an ``n`` attribute fall back to ``shape[0]``.
        # Only the missing attribute is caught, so unrelated errors
        # (including KeyboardInterrupt) are no longer swallowed.
        try:
            nact = ac_space.n
        except AttributeError:
            nact = ac_space.shape[0]
        self.ac_shape = nact * nstack
        try:
            self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        except AttributeError:
            self.all_ac_shape = sum(ac.shape[0] for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # 'decentralized' uses this agent's widths, 'decentralized-all' the
        # summed widths of every agent; the placeholders are otherwise equal.
        if disc_type == 'decentralized':
            ob_dim, ac_dim = self.ob_shape, self.ac_shape
        elif disc_type == 'decentralized-all':
            ob_dim, ac_dim = self.all_ob_shape, self.all_ac_shape
        else:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.obs = tf.placeholder(tf.float32, (None, ob_dim))
        self.nobs = tf.placeholder(tf.float32, (None, ob_dim))
        self.act = tf.placeholder(tf.float32, (None, ac_dim))
        self.labels = tf.placeholder(tf.float32, (None, 1))
        self.lprobs = tf.placeholder(tf.float32, (None, 1))

        self.lr_rate = tf.placeholder(tf.float32, ())

        with tf.variable_scope(self.scope):
            rew_input = self.obs
            if not self.state_only:
                rew_input = tf.concat([self.obs, self.act], axis=1)

            # NOTE(review): ``relu_net`` is called without ``hidden_size``,
            # so its own default (128) is used and ``self.hidden_size`` has
            # no effect on the network width -- confirm this is intended.
            with tf.variable_scope('reward'):
                self.reward = self.relu_net(rew_input, dout=1)

            # One value network, applied to the next and the current
            # observations with shared variables.
            with tf.variable_scope('vfn'):
                self.value_fn_n = self.relu_net(self.nobs, dout=1)
            with tf.variable_scope('vfn', reuse=True):
                self.value_fn = self.relu_net(self.obs, dout=1)

            log_q_tau = self.lprobs
            log_p_tau = self.reward + self.gamma * self.value_fn_n - self.value_fn
            # log(exp(log_p) + exp(log_q)), computed element-wise.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            # exp(log_p) / (exp(log_p) + exp(log_q))
            self.discrim_output = tf.exp(log_p_tau - log_pq)

        # Binary cross-entropy written in log space: rows labelled 1 are
        # scored with log_p - log_pq, rows labelled 0 with log_q - log_pq.
        self.total_loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                          (1 - self.labels) *
                                          (log_q_tau - log_pq))
        self.var_list = self.get_trainable_variables()
        params = find_trainable_variables(self.scope)
        self.l2_loss = tf.add_n([tf.nn.l2_loss(v)
                                 for v in params]) * self.l2_loss_ratio
        self.total_loss += self.l2_loss

        grads = tf.gradients(self.total_loss, params)
        with tf.variable_scope(self.scope + '/d_optim'):
            d_optim = tf.train.AdamOptimizer(learning_rate=self.lr_rate)
            train_op = d_optim.apply_gradients(list(zip(grads, params)))
        self.d_optim = train_op
        self.saver = tf.train.Saver(self.get_variables())

        self.params_flat = self.get_trainable_variables()

    def relu_net(self, x, layers=2, dout=1, hidden_size=128):
        """Stack ``layers`` ``relu_layer`` blocks and a final ``linear`` layer.

        Layer names are 'l0', 'l1', ... and 'lfinal'; the output width is
        ``dout``.
        """
        out = x
        for i in range(layers):
            out = relu_layer(out, dout=hidden_size, name='l%d' % i)
        out = linear(out, dout=dout, name='lfinal')
        return out

    def tanh_net(self, x, layers=2, dout=1, hidden_size=128):
        """Same layout as ``relu_net`` but built from ``tanh_layer`` blocks."""
        out = x
        for i in range(layers):
            out = tanh_layer(out, dout=hidden_size, name='l%d' % i)
        out = linear(out, dout=dout, name='lfinal')
        return out

    def get_variables(self):
        """Return the global variables collected under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)

    def get_trainable_variables(self):
        """Return the trainable variables collected under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_reward(self, obs, acs, obs_next, path_probs, discrim_score=False):
        """Score the given rows.

        Args:
            obs, acs: observation/action rows; one-dimensional inputs are
                promoted to a batch of one.
            obs_next, path_probs: fed to ``nobs``/``lprobs``; only used when
                ``discrim_score`` is true (and passed through unchanged).
            discrim_score: when true return ``log(D) - log(1 - D)`` of the
                discriminator output ``D``; otherwise return the raw reward
                network output.
        """
        if len(obs.shape) == 1:
            obs = np.expand_dims(obs, 0)
        if len(acs.shape) == 1:
            acs = np.expand_dims(acs, 0)
        if discrim_score:
            feed_dict = {
                self.obs: obs,
                self.act: acs,
                self.nobs: obs_next,
                self.lprobs: path_probs
            }
            scores = self.sess.run(self.discrim_output, feed_dict)
            # The 1e-20 offsets keep both logarithms finite at 0 and 1.
            score = np.log(scores + 1e-20) - np.log(1 - scores + 1e-20)
        else:
            feed_dict = {self.obs: obs, self.act: acs}
            score = self.sess.run(self.reward, feed_dict)
        return score

    def train(self, g_obs, g_acs, g_nobs, g_probs, e_obs, e_acs, e_nobs,
              e_probs):
        """Run one optimizer step on the stacked ``g`` and ``e`` batches.

        The ``g_*`` rows are labelled 0 and the ``e_*`` rows 1; both groups
        are concatenated along axis 0 in that order.

        Returns:
            The value of ``total_loss`` fetched together with the update.
        """
        labels = np.concatenate(
            (np.zeros([g_obs.shape[0], 1]), np.ones([e_obs.shape[0], 1])),
            axis=0)
        feed_dict = {
            self.obs: np.concatenate([g_obs, e_obs], axis=0),
            self.act: np.concatenate([g_acs, e_acs], axis=0),
            self.nobs: np.concatenate([g_nobs, e_nobs], axis=0),
            self.lprobs: np.concatenate([g_probs, e_probs], axis=0),
            self.labels: labels,
            self.lr_rate: self.lr.value()
        }
        loss, _ = self.sess.run([self.total_loss, self.d_optim], feed_dict)
        return loss

    def restore(self, path):
        """Restore the scope's variables from a ``tf.train.Saver`` file."""
        print('restoring from:' + path)
        self.saver.restore(self.sess, path)

    def save(self, save_path):
        """Dump the current trainable-variable values with ``joblib``."""
        ps = self.sess.run(self.params_flat)
        joblib.dump(ps, save_path)

    def load(self, load_path):
        """Assign values written by ``save`` back to the trainable variables.

        Values are matched to variables by position in ``params_flat``.
        """
        loaded_params = joblib.load(load_path)
        restores = [
            p.assign(loaded_p)
            for p, loaded_p in zip(self.params_flat, loaded_params)
        ]
        self.sess.run(restores)
    def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 lr_rate=0.01,
                 total_steps=50000,
                 scope="discriminator",
                 kfac_clip=0.001,
                 max_grad_norm=0.5):
        """Build a sigmoid cross-entropy discriminator trained with Adam.

        Args:
            sess: session object, stored on the instance as ``self.sess``.
            ob_spaces: per-agent observation spaces; only ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; ``n`` is used when present,
                otherwise ``shape[0]``.
            nstack: multiplier applied to every observation/action width.
            index: position of this discriminator's agent in the space lists.
            disc_type: 'decentralized' or 'decentralized-all'; it must also
                be contained in the module-level ``disc_types``.
            hidden_size: stored as ``self.hidden_size``.
            lr_rate: initial value handed to the linear ``Scheduler``.
            total_steps: ``nvalues`` of the learning-rate ``Scheduler``.
            scope: variable-scope name used to look up the variables.
            kfac_clip: accepted but not used by this constructor.
            max_grad_norm: accepted but not used by this constructor.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate, nvalues=total_steps, schedule='linear')
        self.disc_type = disc_type
        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if disc_type not in disc_types:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))
        self.scope = scope
        self.index = index
        self.sess = sess
        ob_space = ob_spaces[index]
        ac_space = ac_spaces[index]
        self.ob_shape = ob_space.shape[0] * nstack
        # Action spaces without an ``n`` attribute fall back to ``shape[0]``.
        # Only the missing attribute is caught, so unrelated errors
        # (including KeyboardInterrupt) are no longer swallowed.
        try:
            nact = ac_space.n
        except AttributeError:
            nact = ac_space.shape[0]
        self.ac_shape = nact * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        try:
            self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        except AttributeError:
            self.all_ac_shape = sum(ac.shape[0] for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # Both modes use the action widths of all agents; 'decentralized'
        # pairs them with this agent's observation width only.
        if disc_type == 'decentralized':
            input_shape = self.ob_shape + self.all_ac_shape
        elif disc_type == 'decentralized-all':
            input_shape = self.all_ob_shape + self.all_ac_shape
        else:
            raise AssertionError('unsupported disc_type: %r' % (disc_type,))

        # ``g`` and ``e`` are the two input batches being told apart
        # (presumably generated vs. expert rows -- confirm against callers).
        self.g = tf.placeholder(tf.float32, (None, input_shape))
        self.e = tf.placeholder(tf.float32, (None, input_shape))
        self.lr_rate = tf.placeholder(tf.float32, ())
        self.adv = tf.placeholder(tf.float32, ())

        num_outputs = 1

        # First call creates the variables; every later call reuses them.
        logits = self.build_graph(tf.concat([self.g, self.e], axis=0),
                                  num_outputs,
                                  reuse=False)
        # Target 0 for rows coming from ``g``, 1 for rows coming from ``e``.
        labels = tf.concat([
            tf.zeros([tf.shape(self.g)[0], 1]),
            tf.ones([tf.shape(self.e)[0], 1])
        ],
                           axis=0)

        g_logits = self.build_graph(self.g, num_outputs, reuse=True)
        e_logits = self.build_graph(self.e, num_outputs, reuse=True)

        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=g_logits, labels=tf.zeros_like(g_logits)))
        self.e_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=e_logits, labels=tf.ones_like(e_logits)))

        self.total_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                    labels=labels))
        # Not consumed by the Adam path below; kept so the graph built here
        # matches the original (it was the input of a KFAC statistics update).
        fisher_loss = -self.total_loss

        # log(D) - log(1 - D) with D = sigmoid(g_logits); the 1e-10 offsets
        # keep both logarithms finite.
        self.reward_op = tf.log(tf.sigmoid(g_logits) +
                                1e-10) - tf.log(1 - tf.sigmoid(g_logits) +
                                                1e-10)

        self.var_list = self.get_trainable_variables()
        params = find_trainable_variables(self.scope)
        grads = tf.gradients(self.total_loss, params)

        with tf.variable_scope(self.scope + '/d_optim'):
            d_optim = tf.train.AdamOptimizer(learning_rate=self.lr_rate)
            train_op = d_optim.apply_gradients(list(zip(grads, params)))

        self.d_optim = train_op
        self.saver = tf.train.Saver(self.get_variables())

        self.params_flat = self.get_trainable_variables()
Example #6
0
    def __init__(self,
                 policy,
                 ob_space,
                 ac_space,
                 nenvs,
                 total_timesteps,
                 nprocs=2,
                 nsteps=200,
                 nstack=1,
                 ent_coef=0.00,
                 vf_coef=0.5,
                 vf_fisher_coef=1.0,
                 lr=0.25,
                 max_grad_norm=0.5,
                 kfac_clip=0.001,
                 lrschedule='linear',
                 identical=None):
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=nprocs,
                                inter_op_parallelism_threads=nprocs)
        config.gpu_options.allow_growth = True
        self.sess = sess = tf.Session(config=config)
        nbatch = nenvs * nsteps
        self.num_agents = num_agents = len(ob_space)
        self.n_actions = [ac_space[k].n for k in range(self.num_agents)]
        if identical is None:
            identical = [False for _ in range(self.num_agents)]

        scale = [1 for _ in range(num_agents)]
        pointer = [i for i in range(num_agents)]
        h = 0
        for k in range(num_agents):
            if identical[k]:
                scale[h] += 1
            else:
                pointer[h] = k
                h = k
        pointer[h] = num_agents

        A, ADV, R, PG_LR = [], [], [], []
        for k in range(num_agents):
            if identical[k]:
                A.append(A[-1])
                ADV.append(ADV[-1])
                R.append(R[-1])
                PG_LR.append(PG_LR[-1])
            else:
                A.append(tf.placeholder(tf.int32, [nbatch * scale[k]]))
                ADV.append(tf.placeholder(tf.float32, [nbatch * scale[k]]))
                R.append(tf.placeholder(tf.float32, [nbatch * scale[k]]))
                PG_LR.append(tf.placeholder(tf.float32, []))

        pg_loss, entropy, vf_loss, train_loss = [], [], [], []
        self.model = step_model = []
        self.model2 = train_model = []
        self.pg_fisher = pg_fisher_loss = []
        self.logits = logits = []
        sample_net = []
        self.vf_fisher = vf_fisher_loss = []
        self.joint_fisher = joint_fisher_loss = []
        self.lld = lld = []
        self.log_pac = []

        for k in range(num_agents):
            if identical[k]:
                step_model.append(step_model[-1])
                train_model.append(train_model[-1])
            else:
                step_model.append(
                    policy(sess,
                           ob_space[k],
                           ac_space[k],
                           ob_space,
                           ac_space,
                           nenvs,
                           1,
                           nstack,
                           reuse=False,
                           name='%d' % k))
                train_model.append(
                    policy(sess,
                           ob_space[k],
                           ac_space[k],
                           ob_space,
                           ac_space,
                           nenvs * scale[k],
                           nsteps,
                           nstack,
                           reuse=True,
                           name='%d' % k))
            logpac = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=train_model[k].pi, labels=A[k])
            self.log_pac.append(-logpac)

            lld.append(tf.reduce_mean(logpac))
            logits.append(train_model[k].pi)

            pg_loss.append(tf.reduce_mean(ADV[k] * logpac))
            entropy.append(tf.reduce_mean(cat_entropy(train_model[k].pi)))
            pg_loss[k] = pg_loss[k] - ent_coef * entropy[k]
            vf_loss.append(
                tf.reduce_mean(mse(tf.squeeze(train_model[k].vf), R[k])))
            train_loss.append(pg_loss[k] + vf_coef * vf_loss[k])

            pg_fisher_loss.append(-tf.reduce_mean(logpac))
            sample_net.append(train_model[k].vf +
                              tf.random_normal(tf.shape(train_model[k].vf)))
            vf_fisher_loss.append(-vf_fisher_coef * tf.reduce_mean(
                tf.pow(train_model[k].vf - tf.stop_gradient(sample_net[k]),
                       2)))
            joint_fisher_loss.append(pg_fisher_loss[k] + vf_fisher_loss[k])

        self.policy_params = []
        self.value_params = []

        for k in range(num_agents):
            if identical[k]:
                self.policy_params.append(self.policy_params[-1])
                self.value_params.append(self.value_params[-1])
            else:
                self.policy_params.append(
                    find_trainable_variables("policy_%d" % k))
                self.value_params.append(
                    find_trainable_variables("value_%d" % k))
        self.params = params = [
            a + b for a, b in zip(self.policy_params, self.value_params)
        ]
        params_flat = []
        for k in range(num_agents):
            params_flat.extend(params[k])

        self.grads_check = grads = [
            tf.gradients(train_loss[k], params[k]) for k in range(num_agents)
        ]
        clone_grads = [
            tf.gradients(lld[k], params[k]) for k in range(num_agents)
        ]

        self.optim = optim = []
        self.clones = clones = []
        update_stats_op = []
        train_op, clone_op, q_runner = [], [], []

        for k in range(num_agents):
            if identical[k]:
                optim.append(optim[-1])
                train_op.append(train_op[-1])
                q_runner.append(q_runner[-1])
                clones.append(clones[-1])
                clone_op.append(clone_op[-1])
            else:
                with tf.variable_scope('optim_%d' % k):
                    optim.append(
                        kfac.KfacOptimizer(learning_rate=PG_LR[k],
                                           clip_kl=kfac_clip,
                                           momentum=0.9,
                                           kfac_update=1,
                                           epsilon=0.01,
                                           stats_decay=0.99,
                                           async_var=0,
                                           cold_iter=10,
                                           max_grad_norm=max_grad_norm))
                    update_stats_op.append(optim[k].compute_and_apply_stats(
                        joint_fisher_loss, var_list=params[k]))
                    train_op_, q_runner_ = optim[k].apply_gradients(
                        list(zip(grads[k], params[k])))
                    train_op.append(train_op_)
                    q_runner.append(q_runner_)

                with tf.variable_scope('clone_%d' % k):
                    clones.append(
                        kfac.KfacOptimizer(learning_rate=PG_LR[k],
                                           clip_kl=kfac_clip,
                                           momentum=0.9,
                                           kfac_update=1,
                                           epsilon=0.01,
                                           stats_decay=0.99,
                                           async_var=0,
                                           cold_iter=10,
                                           max_grad_norm=max_grad_norm))
                    update_stats_op.append(clones[k].compute_and_apply_stats(
                        pg_fisher_loss[k], var_list=self.policy_params[k]))
                    clone_op_, q_runner_ = clones[k].apply_gradients(
                        list(zip(clone_grads[k], self.policy_params[k])))
                    clone_op.append(clone_op_)

        update_stats_op = tf.group(*update_stats_op)
        train_ops = train_op
        clone_ops = clone_op
        train_op = tf.group(*train_op)
        clone_op = tf.group(*clone_op)

        self.q_runner = q_runner
        self.lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
        self.clone_lr = Scheduler(v=lr,
                                  nvalues=total_timesteps,
                                  schedule=lrschedule)

        def train(obs, states, rewards, masks, actions, values):
            """Run one optimizer step per non-identical agent, then evaluate losses.

            For each agent ``k`` with ``identical[k]`` false, the entries
            ``k .. pointer[k] - 1`` of the per-agent inputs are concatenated
            along axis 0 into a feed dict, ``train_ops[k]`` is run with it, and
            the dict is merged into a combined feed dict.

            Returns:
                ``(policy_loss, value_loss, policy_entropy)``: the results of
                evaluating the ``pg_loss``, ``vf_loss`` and ``entropy`` lists
                with the combined feed dict.
            """
            # Per-agent advantage: reward minus the supplied value estimate.
            advs = [rewards[k] - values[k] for k in range(num_agents)]
            # Queries the scheduler once per element of ``obs``; only the last
            # value is kept. Presumably each call advances the schedule --
            # TODO confirm against Scheduler.
            for step in range(len(obs)):
                cur_lr = self.lr.value()

            # Every agent's observations joined along axis 1.
            ob = np.concatenate(obs, axis=1)

            # Union of all per-agent feeds, used for the final loss evaluation.
            td_map = {}
            for k in range(num_agents):
                if identical[k]:
                    continue
                new_map = {}
                if num_agents > 1:
                    # One-hot actions of all agents except k, built once per
                    # member j of the group and stacked along axis 0.
                    # NOTE(review): the comprehension excludes k for every j
                    # (j itself is unused) -- confirm this is intended.
                    action_v = []
                    for j in range(k, pointer[k]):
                        action_v.append(
                            np.concatenate([
                                multionehot(actions[i], self.n_actions[i])
                                for i in range(num_agents) if i != k
                            ],
                                           axis=1))
                    action_v = np.concatenate(action_v, axis=0)
                    new_map.update({train_model[k].A_v: action_v})
                    td_map.update({train_model[k].A_v: action_v})

                # The joint observation is repeated once per group member so
                # X_v has as many rows as X.
                new_map.update({
                    train_model[k].X:
                    np.concatenate([obs[j] for j in range(k, pointer[k])],
                                   axis=0),
                    train_model[k].X_v:
                    np.concatenate([ob.copy() for j in range(k, pointer[k])],
                                   axis=0),
                    A[k]:
                    np.concatenate([actions[j] for j in range(k, pointer[k])],
                                   axis=0),
                    ADV[k]:
                    np.concatenate([advs[j] for j in range(k, pointer[k])],
                                   axis=0),
                    R[k]:
                    np.concatenate([rewards[j] for j in range(k, pointer[k])],
                                   axis=0),
                    PG_LR[k]:
                    cur_lr / float(scale[k])
                })
                # The optimizer step for agent k sees only its own feed.
                sess.run(train_ops[k], feed_dict=new_map)
                td_map.update(new_map)

                # NOTE(review): the whole ``states`` / ``masks`` arguments are
                # fed here, not ``states[k]`` -- confirm against the policy's
                # S / M placeholders. They are added after the train step, so
                # they only reach the final loss evaluation.
                if states[k] != []:
                    td_map[train_model[k].S] = states
                    td_map[train_model[k].M] = masks

            policy_loss, value_loss, policy_entropy = sess.run(
                [pg_loss, vf_loss, entropy], td_map)
            return policy_loss, value_loss, policy_entropy

        def clone(obs, actions):
            """Run ``clone_ops[k]`` for each non-identical agent and report ``lld``.

            The feed for agent ``k`` stacks entries ``k .. pointer[k] - 1`` of
            ``obs`` and ``actions`` along axis 0 and sets the learning rate to
            the current ``clone_lr`` value divided by ``scale[k]``.

            Returns:
                The result of ``sess.run([lld], ...)`` with the union of all
                per-agent feeds, i.e. a one-element list wrapping the
                evaluated ``lld`` list.
            """
            step_lr = self.clone_lr.value()
            merged_feed = {}
            for k in range(num_agents):
                if not identical[k]:
                    group = range(k, pointer[k])
                    agent_feed = {
                        train_model[k].X:
                        np.concatenate([obs[j] for j in group], axis=0),
                        A[k]:
                        np.concatenate([actions[j] for j in group], axis=0),
                        PG_LR[k]:
                        step_lr / float(scale[k]),
                    }
                    # Each clone step only sees the feed of its own agent.
                    sess.run(clone_ops[k], feed_dict=agent_feed)
                    merged_feed.update(agent_feed)
            return sess.run([lld], merged_feed)

        def get_log_action_prob(obs, actions):
            """Evaluate ``self.log_pac[k]`` for every non-identical agent.

            Entries ``k .. pointer[k] - 1`` of ``obs`` and ``actions`` are
            stacked along axis 0 and fed to the train model of agent ``k``.
            When ``scale[k]`` is not 1 the result is split back into
            ``scale[k]`` equal pieces along axis 0.

            Returns:
                A flat list of the evaluated arrays, in agent order.
            """
            collected = []
            for k in range(num_agents):
                if identical[k]:
                    continue
                group = range(k, pointer[k])
                stacked_obs = np.concatenate([obs[j] for j in group], axis=0)
                stacked_act = np.concatenate([actions[j] for j in group],
                                             axis=0)
                result = sess.run(self.log_pac[k],
                                  feed_dict={
                                      train_model[k].X: stacked_obs,
                                      A[k]: stacked_act
                                  })
                if scale[k] != 1:
                    # Undo the stacking: one piece per member of the group.
                    collected.extend(np.split(result, scale[k], axis=0))
                else:
                    collected.append(result)
            return collected

        self.get_log_action_prob = get_log_action_prob

        def get_log_action_prob_step(obs, actions):
            """Return ``step_model[k].step_log_prob(obs[k], actions[k])`` per agent.

            The result is a list with one entry for each of the
            ``num_agents`` agents, in agent order.
            """
            return [
                step_model[k].step_log_prob(obs[k], actions[k])
                for k in range(num_agents)
            ]

        self.get_log_action_prob_step = get_log_action_prob_step

        def save(save_path):
            """Evaluate ``params_flat`` and dump the values to ``save_path``."""
            joblib.dump(sess.run(params_flat), save_path)

        def load(load_path):
            """Assign values stored at ``load_path`` to ``params_flat``.

            Stored values are paired with variables positionally; pairing
            stops at the shorter of the two sequences.
            """
            # NOTE: joblib.load unpickles its input -- only load trusted files.
            stored_values = joblib.load(load_path)
            assign_ops = [
                variable.assign(stored)
                for variable, stored in zip(params_flat, stored_values)
            ]
            sess.run(assign_ops)

        self.train = train
        self.clone = clone
        self.save = save
        self.load = load
        self.train_model = train_model
        self.step_model = step_model

        def step(ob, av, *_args, **_kwargs):
            """Call ``step`` on every agent's step model.

            Each ``step_model[k].step`` receives the agent's own ``ob[k]``,
            all observations concatenated along axis 1, and the one-hot
            encoding of the other agents' entries of ``av`` concatenated along
            axis 1. Extra positional and keyword arguments are ignored.

            Returns:
                Three lists with one entry per agent, holding the first,
                second and third value returned by each ``step`` call.
            """
            joint_ob = np.concatenate(ob, axis=1)
            agent_ids = range(num_agents)
            acts, vals, sts = [], [], []
            for k in agent_ids:
                # NOTE(review): with a single agent this list is empty and
                # np.concatenate raises -- confirm single-agent use is not
                # expected here.
                others = [
                    multionehot(av[i], self.n_actions[i]) for i in agent_ids
                    if i != k
                ]
                act_k, val_k, st_k = step_model[k].step(
                    ob[k], joint_ob, np.concatenate(others, axis=1))
                acts.append(act_k)
                vals.append(val_k)
                sts.append(st_k)
            return acts, vals, sts

        self.step = step

        def value(obs, av):
            """Call ``value`` on every agent's step model.

            Each ``step_model[k].value`` receives all observations
            concatenated along axis 1 and the one-hot encoding of the other
            agents' entries of ``av`` concatenated along axis 1.

            Returns:
                A list with the result of each call, in agent order.
            """
            joint_ob = np.concatenate(obs, axis=1)
            agent_ids = range(num_agents)
            estimates = []
            for k in agent_ids:
                others = [
                    multionehot(av[i], self.n_actions[i]) for i in agent_ids
                    if i != k
                ]
                estimates.append(step_model[k].value(
                    joint_ob, np.concatenate(others, axis=1)))
            return estimates

        self.value = value
        self.initial_state = [
            step_model[k].initial_state for k in range(num_agents)
        ]
    def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 gp_coef=5,
                 lr_rate=5e-4,
                 total_steps=50000,
                 scope="discriminator"):
        """Build the gradient-penalised critic graph and its optimizer.

        Args:
            sess: session object, stored as ``self.sess``.
            ob_spaces: per-agent observation spaces; ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; ``n`` is read.
            nstack: multiplier applied to every observation/action width.
            index: agent whose own spaces define ``ob_shape`` / ``ac_shape``.
            disc_type: must be in ``disc_types``. 'decentralized' feeds all
                observations plus this agent's actions; 'centralized' and
                'single' feed all observations plus all actions.
                'centralized' additionally uses one output per agent.
            hidden_size: stored as ``self.hidden_size``.
            gp_coef: weight of the gradient penalty in ``total_loss``.
            lr_rate: start value of the linear learning-rate schedule.
            total_steps: the schedule spans ``total_steps * 20`` values.
            scope: stored as ``self.scope``; also prefixes the bias variable.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate,
                            nvalues=total_steps * 20,
                            schedule='linear')
        self.disc_type = disc_type
        if disc_type not in disc_types:
            # Explicit raise instead of `assert False`: same exception type
            # for callers, but not stripped when Python runs with -O.
            raise AssertionError('unknown disc_type: %r' % (disc_type, ))
        self.scope = scope
        self.index = index
        self.sess = sess
        self.ob_shape = ob_spaces[index].shape[0] * nstack
        self.ac_shape = ac_spaces[index].n * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # Input width: every agent's observations plus either this agent's
        # actions ('decentralized') or every agent's actions.
        if disc_type == 'decentralized':
            input_shape = self.all_ob_shape + self.ac_shape
        elif disc_type in ('centralized', 'single'):
            input_shape = self.all_ob_shape + self.all_ac_shape
        else:
            raise AssertionError('unhandled disc_type: %r' % (disc_type, ))

        # g: generator samples, e: expert samples (one row per sample).
        self.g = tf.placeholder(tf.float32, (None, input_shape))
        self.e = tf.placeholder(tf.float32, (None, input_shape))
        self.lr_rate = tf.placeholder(tf.float32, ())

        num_outputs = len(ob_spaces) if disc_type == 'centralized' else 1
        # Non-trainable bias, updated only through `update_bias` as an
        # exponential moving average (0.99 old / 0.01 new).
        self.bias = tf.get_variable(name=scope + '_bias',
                                    shape=(num_outputs, ),
                                    initializer=tf.zeros_initializer,
                                    trainable=False)
        self.bias_ph = tf.placeholder(tf.float32, (num_outputs, ))
        self.update_bias = tf.assign(self.bias,
                                     self.bias_ph * 0.01 + self.bias * 0.99)

        generator_logits = self.build_graph(self.g, num_outputs, reuse=False)
        expert_logits = self.build_graph(self.e, num_outputs, reuse=True)

        self.generator_loss = tf.reduce_mean(generator_logits, axis=0)
        self.expert_loss = tf.reduce_mean(expert_logits, axis=0)

        # Per-sample input-gradient norms at the generator and expert
        # points. These are not part of `total_loss`.
        ddg = tf.gradients(generator_logits, [self.g])
        ddg = tf.sqrt(tf.reduce_sum(tf.square(ddg[0]), axis=1))
        self.ddg = tf.reduce_mean(tf.square(ddg - 1.))

        dde = tf.gradients(expert_logits, [self.e])
        dde = tf.sqrt(tf.reduce_sum(tf.square(dde[0]), axis=1))
        self.dde = tf.reduce_mean(tf.square(dde - 1.))

        # Gradient penalty at a random interpolation of generator and expert
        # batches (one scalar mixing weight per run).
        epsilon = tf.random_uniform([], 0.0, 1.0)
        ge = self.g * epsilon + self.e * (1 - epsilon)
        gel = self.build_graph(ge, num_outputs, reuse=True)
        ddd = tf.gradients(gel, [ge])
        # tf.gradients returns a one-element list. Passing the list itself
        # to tf.norm packs it into a [1, batch, features] tensor, so axis=1
        # would reduce over the batch. Index it first so the norm is taken
        # per sample over the feature axis, matching ddg / dde above.
        ddd = tf.norm(ddd[0], axis=1)
        self.ddd = tf.reduce_mean(tf.square(ddd - 1.))

        self.total_loss = (self.generator_loss - self.expert_loss +
                           gp_coef * self.ddd)
        self.reward_op = generator_logits

        self.var_list = self.get_trainable_variables()
        self.d_optim = tf.train.AdamOptimizer(self.lr_rate,
                                              beta1=0.5,
                                              beta2=0.9).minimize(
                                                  self.total_loss,
                                                  var_list=self.var_list)
        self.saver = tf.train.Saver(self.get_variables())
class Discriminator(object):
    """Critic trained with a gradient penalty on generator vs. expert samples.

    ``build_graph`` maps a concatenated observation/action vector to
    ``num_outputs`` logits. ``total_loss`` is the mean generator logit minus
    the mean expert logit plus ``gp_coef`` times a penalty that pushes the
    per-sample input-gradient norm at interpolated inputs towards 1.
    """

    def __init__(self,
                 sess,
                 ob_spaces,
                 ac_spaces,
                 nstack,
                 index,
                 disc_type='decentralized',
                 hidden_size=128,
                 gp_coef=5,
                 lr_rate=5e-4,
                 total_steps=50000,
                 scope="discriminator"):
        """Build the critic graph and its optimizer.

        Args:
            sess: session used by ``get_reward``, ``train`` and ``restore``.
            ob_spaces: per-agent observation spaces; ``shape[0]`` is read.
            ac_spaces: per-agent action spaces; ``n`` is read.
            nstack: multiplier applied to every observation/action width.
            index: agent whose own spaces define ``ob_shape`` / ``ac_shape``.
            disc_type: must be in ``disc_types``. 'decentralized' feeds all
                observations plus this agent's actions; 'centralized' and
                'single' feed all observations plus all actions.
                'centralized' additionally uses one output per agent.
            hidden_size: width of the three hidden layers in ``build_graph``.
            gp_coef: weight of the gradient penalty in ``total_loss``.
            lr_rate: start value of the linear learning-rate schedule.
            total_steps: the schedule spans ``total_steps * 20`` values.
            scope: variable scope of the network; also prefixes the bias.

        Raises:
            AssertionError: if ``disc_type`` is not supported.
        """
        self.lr = Scheduler(v=lr_rate,
                            nvalues=total_steps * 20,
                            schedule='linear')
        self.disc_type = disc_type
        if disc_type not in disc_types:
            # Explicit raise instead of `assert False`: same exception type
            # for callers, but not stripped when Python runs with -O.
            raise AssertionError('unknown disc_type: %r' % (disc_type, ))
        self.scope = scope
        self.index = index
        self.sess = sess
        self.ob_shape = ob_spaces[index].shape[0] * nstack
        self.ac_shape = ac_spaces[index].n * nstack
        self.all_ob_shape = sum(obs.shape[0] for obs in ob_spaces) * nstack
        self.all_ac_shape = sum(ac.n for ac in ac_spaces) * nstack
        self.hidden_size = hidden_size

        # Input width: every agent's observations plus either this agent's
        # actions ('decentralized') or every agent's actions.
        if disc_type == 'decentralized':
            input_shape = self.all_ob_shape + self.ac_shape
        elif disc_type in ('centralized', 'single'):
            input_shape = self.all_ob_shape + self.all_ac_shape
        else:
            raise AssertionError('unhandled disc_type: %r' % (disc_type, ))

        # g: generator samples, e: expert samples (one row per sample).
        self.g = tf.placeholder(tf.float32, (None, input_shape))
        self.e = tf.placeholder(tf.float32, (None, input_shape))
        self.lr_rate = tf.placeholder(tf.float32, ())

        num_outputs = len(ob_spaces) if disc_type == 'centralized' else 1
        # Non-trainable bias subtracted from the logits in build_graph. It is
        # updated only through `update_bias`, an exponential moving average
        # (0.99 old / 0.01 new).
        self.bias = tf.get_variable(name=scope + '_bias',
                                    shape=(num_outputs, ),
                                    initializer=tf.zeros_initializer,
                                    trainable=False)
        self.bias_ph = tf.placeholder(tf.float32, (num_outputs, ))
        self.update_bias = tf.assign(self.bias,
                                     self.bias_ph * 0.01 + self.bias * 0.99)

        generator_logits = self.build_graph(self.g, num_outputs, reuse=False)
        expert_logits = self.build_graph(self.e, num_outputs, reuse=True)

        self.generator_loss = tf.reduce_mean(generator_logits, axis=0)
        self.expert_loss = tf.reduce_mean(expert_logits, axis=0)

        # Per-sample input-gradient norms at the generator and expert
        # points. They are reported by train() but are not part of
        # `total_loss`.
        ddg = tf.gradients(generator_logits, [self.g])
        ddg = tf.sqrt(tf.reduce_sum(tf.square(ddg[0]), axis=1))
        self.ddg = tf.reduce_mean(tf.square(ddg - 1.))

        dde = tf.gradients(expert_logits, [self.e])
        dde = tf.sqrt(tf.reduce_sum(tf.square(dde[0]), axis=1))
        self.dde = tf.reduce_mean(tf.square(dde - 1.))

        # Gradient penalty at a random interpolation of generator and expert
        # batches (one scalar mixing weight per run).
        epsilon = tf.random_uniform([], 0.0, 1.0)
        ge = self.g * epsilon + self.e * (1 - epsilon)
        gel = self.build_graph(ge, num_outputs, reuse=True)
        ddd = tf.gradients(gel, [ge])
        # tf.gradients returns a one-element list. Passing the list itself
        # to tf.norm packs it into a [1, batch, features] tensor, so axis=1
        # would reduce over the batch. Index it first so the norm is taken
        # per sample over the feature axis, matching ddg / dde above.
        ddd = tf.norm(ddd[0], axis=1)
        self.ddd = tf.reduce_mean(tf.square(ddd - 1.))

        self.total_loss = (self.generator_loss - self.expert_loss +
                           gp_coef * self.ddd)
        self.reward_op = generator_logits

        self.var_list = self.get_trainable_variables()
        self.d_optim = tf.train.AdamOptimizer(self.lr_rate,
                                              beta1=0.5,
                                              beta2=0.9).minimize(
                                                  self.total_loss,
                                                  var_list=self.var_list)
        self.saver = tf.train.Saver(self.get_variables())

    def build_graph(self, x, num_outputs=1, reuse=False):
        """Apply the critic network to ``x`` and return its logits.

        Three ``fc`` layers of width ``self.hidden_size`` are followed by an
        ``fc`` output layer with identity activation; ``self.bias`` is then
        subtracted. With ``reuse=True`` the variables already created under
        ``self.scope`` are shared.
        """
        with tf.variable_scope(self.scope):
            if reuse:
                tf.get_variable_scope().reuse_variables()
            p_h1 = fc(x, 'fc1', nh=self.hidden_size)
            p_h2 = fc(p_h1, 'fc2', nh=self.hidden_size)
            p_h3 = fc(p_h2, 'fc3', nh=self.hidden_size)
            logits = fc(p_h3, 'out', nh=num_outputs, act=lambda h: h)
            logits -= self.bias
        return logits

    def get_variables(self):
        """Return the global variables collected under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)

    def get_trainable_variables(self):
        """Return the trainable variables collected under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_reward(self, all_obs, acs):
        """Evaluate the generator logits for the given observations/actions.

        ``all_obs`` and ``acs`` may be arrays or array-likes. Inputs with
        fewer than two dimensions are treated as a single sample and
        promoted to one row before being concatenated along axis 1.
        """
        all_obs = np.atleast_2d(all_obs)
        acs = np.atleast_2d(acs)
        feed_dict = {self.g: np.concatenate([all_obs, acs], axis=1)}
        return self.sess.run(self.reward_op, feed_dict)

    def train(self, g_all_obs, g_acs, e_all_obs, e_acs):
        """Apply one optimizer step and return post-update diagnostics.

        Returns:
            The evaluated ``[generator_loss, expert_loss, ddg, dde]``,
            computed with the same feed after the update.
        """
        feed_dict = {
            self.g: np.concatenate([g_all_obs, g_acs], axis=1),
            self.e: np.concatenate([e_all_obs, e_acs], axis=1),
            self.lr_rate: self.lr.value()
        }
        # Only the update op is run here; the losses are evaluated below,
        # after the step, so nothing else needs to be fetched.
        # NOTE: `self.update_bias` is not run by this method.
        self.sess.run(self.d_optim, feed_dict)
        return self.sess.run(
            [self.generator_loss, self.expert_loss, self.ddg, self.dde],
            feed_dict)

    def restore(self, path):
        """Restore the variables tracked by ``self.saver`` from ``path``."""
        print('restoring from:' + path)
        self.saver.restore(self.sess, path)