Example #1
    def __init__(self, env, policy_net, summary_writer, saver=None):

        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.env = env

        # Correct the path
        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))
        print('[PM] checkpoint_path: {}'.format(self.checkpoint_path))

        # Local policy net
        with tf.variable_scope("policy_eval"):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))
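These snippets all build their sync op with a make_copy_params_op helper that the page does not show. As a reference point only, a minimal sketch of such a helper, assuming (as in common A3C reference code) that the two variable lists can be aligned by sorting on variable name:

def make_copy_params_op(v1_list, v2_list):
    """Creates ops that copy the values of the variables in v1_list
    into the corresponding variables in v2_list.

    The lists are aligned by sorting on variable name, so both scopes
    must define the same variables in the same structure."""
    v1_list = list(sorted(v1_list, key=lambda v: v.name))
    v2_list = list(sorted(v2_list, key=lambda v: v.name))

    update_ops = []
    for v1, v2 in zip(v1_list, v2_list):
        op = v2.assign(v1)
        update_ops.append(op)
    return update_ops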
Example #2
  def __init__(self, env, policy_net, summary_writer, saver=None):


    self.global_policy_net = policy_net
    self.summary_writer = summary_writer
    self.saver = saver
    #self.sp = StateProcessor()

    self.env = CDLL('./PythonAccessToSim.so')
    self.env.step.restype = step_result
    self.env.send_command.restype = c_int
    self.env.initialize.restype = c_int
    self.env.recieve_state_gui.restype = step_result

    self.actions = list(range(0, 3 * Num_Targets))

    self.checkpoint_path = os.path.abspath(os.path.join(summary_writer.get_logdir(), "./checkpoints/model"))
    print(self.checkpoint_path)
    # Local policy net
    with tf.variable_scope("policy_eval"):
      self.policy_net = PolicyEstimator(policy_net.num_outputs)

    # Op to copy params from global policy/value net parameters
    self.copy_params_op = make_copy_params_op(
      tf.contrib.slim.get_variables(scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
      tf.contrib.slim.get_variables(scope="policy_eval", collection=tf.GraphKeys.TRAINABLE_VARIABLES))
Example #3
    def testPredict(self):
        env = make_env()
        sp = StateProcessor()
        estimator = PolicyEstimator(len(VALID_ACTIONS))

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
                estimator.actions: [1]
            }

            loss = sess.run(estimator.loss, feed_dict)
            pred = sess.run(estimator.predictions, feed_dict)

            # Assertions
            self.assertTrue(loss != 0.0)
            self.assertEqual(pred["probs"].shape, (1, len(VALID_ACTIONS)))
            self.assertEqual(pred["logits"].shape, (1, len(VALID_ACTIONS)))
Example #4
    def testGradient(self):
        env = make_env()
        sp = StateProcessor()
        estimator = PolicyEstimator(len(VALID_ACTIONS))
        grads = [g for g, _ in estimator.grads_and_vars]

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())

            # Generate a state
            state = sp.process(env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            processed_states = np.array([processed_state])

            # Run feeds to get gradients
            feed_dict = {
                estimator.states: processed_states,
                estimator.targets: [1.0],
                estimator.actions: [1]
            }

            grads_ = sess.run(grads, feed_dict)

            # Apply calculated gradients
            grad_feed_dict = {k: v for k, v in zip(grads, grads_)}
            _ = sess.run(estimator.train_op, grad_feed_dict)
Example #5
    def __init__(self, env, policy_net, summary_writer, saver=None):

        self.video_dir = os.path.join(summary_writer.get_logdir(), "../videos")
        self.video_dir = os.path.abspath(self.video_dir)

        self.env = Monitor(env, directory=self.video_dir, video_callable=lambda x: True, resume=True)
        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.sp = StateProcessor()

        self.checkpoint_path = os.path.abspath(os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))

        try:
            os.makedirs(self.video_dir)
        except FileExistsError:
            pass

        # Local policy net
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(scope="policy_eval", collection=tf.GraphKeys.TRAINABLE_VARIABLES))
Example #6
    def __init__(self, env, policy_net, summary_writer, saver=None):

        self.video_dir = os.path.join(summary_writer.get_logdir(), "../videos")
        self.video_dir = os.path.abspath(self.video_dir)

        self.env = Monitor(env,
                           directory=self.video_dir,
                           video_callable=lambda x: True,
                           resume=True)
        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver

        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))

        try:
            os.makedirs(self.video_dir)
        except OSError as e:
            # FileExistsError was only added in Python 3.3, so catch OSError
            # and ignore only EEXIST (requires `import errno`). See
            # https://stackoverflow.com/questions/20790580/python-specifically-handle-file-exists-exception
            if e.errno != errno.EEXIST:
                raise

        # Local policy net
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))
Example #7
  def __init__(self, name, env, policy_net, value_net, global_counter, discount_factor=0.99, summary_writer=None, max_global_steps=None):
    self.name = name
    self.discount_factor = discount_factor
    self.max_global_steps = max_global_steps
    self.global_step = tf.contrib.framework.get_global_step()
    self.global_policy_net = policy_net
    self.global_value_net = value_net
    self.global_counter = global_counter
    self.local_counter = itertools.count()
    self.summary_writer = summary_writer
    self.env = env

    # Create local policy/value nets that are not updated asynchronously
    with tf.variable_scope(name):
      self.policy_net = PolicyEstimator(policy_net.num_outputs)
      self.value_net = ValueEstimator(reuse=True)

    # Op to copy params from global policy/value nets
    self.copy_params_op = make_copy_params_op(
      tf.contrib.slim.get_variables(scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
      tf.contrib.slim.get_variables(scope=self.name+'/', collection=tf.GraphKeys.TRAINABLE_VARIABLES))

    self.vnet_train_op = make_train_op(self.value_net, self.global_value_net)
    self.pnet_train_op = make_train_op(self.policy_net, self.global_policy_net)

    self.state = None
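make_train_op is also an external helper here: it wires gradients computed on the local estimator into the variables of the global estimator. A rough sketch, assuming each estimator exposes grads_and_vars and an optimizer as in the usual A3C reference implementations (the 5.0 clipping norm is an illustrative choice, not taken from this code):

import tensorflow as tf

def make_train_op(local_estimator, global_estimator):
    """Creates an op that applies the local estimator's gradients
    to the global estimator's variables."""
    local_grads, _ = zip(*local_estimator.grads_and_vars)
    # Clip gradients before touching the shared parameters
    local_grads, _ = tf.clip_by_global_norm(local_grads, 5.0)
    _, global_vars = zip(*global_estimator.grads_and_vars)
    local_global_grads_and_vars = list(zip(local_grads, global_vars))
    return global_estimator.optimizer.apply_gradients(
        local_global_grads_and_vars,
        global_step=tf.contrib.framework.get_global_step())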
Example #8
def main():
    VALID_ACTIONS = list(range(10))
    MODEL_DIR = "data/train"
    CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")

    env = simulator.Task(debug_flag=True, test_flag=True, state_blink=True, state_inaccurate=True)

    with tf.variable_scope("global") as vs:
        policy_net = PolicyEstimator(num_outputs=len(VALID_ACTIONS))

    saver = tf.train.Saver(keep_checkpoint_every_n_hours=2.0, max_to_keep=10)
    pe = PolicyMonitor(
        env=env,
        policy_net=policy_net)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()

    # Load a previous checkpoint if it exists
    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    if latest_checkpoint:
        print("Loading model checkpoint: {}".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
    else:
        print("Fail to load model")
        return

    with sess.as_default(), sess.graph.as_default():
        _ = sess.run(pe.copy_params_op)

    done = False
    state = np.stack([pe.env.reset().reshape(93,93)] * 4, axis=2)
    rnn_state = pe.policy_net.state_init
    total_reward = 0.0

    num_episode = 0
    while True:
        action_probs, next_rnn_state = pe._policy_net_predict(state, rnn_state, sess)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        next_state, reward, done = pe.env.step(action)
        next_state = np.append(state[:,:,1:], np.expand_dims(next_state.reshape(93,93), 2), axis=2)
        total_reward += reward
        state = next_state
        rnn_state = next_rnn_state

        if done:  # the latest transition ended the episode and has already been stored
            done = False
            state = np.stack([pe.env.reset().reshape(93,93)] * 4, axis=2)
            rnn_state = pe.policy_net.state_init
            total_reward = 0.0
            num_episode += 1

        cv2.waitKey(30)

        if num_episode >= 10:
            break
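The state handling in this loop keeps a sliding window of the last 4 frames stacked along the channel axis. The same logic, factored into two small helpers purely for illustration (the 93x93 frame size and the window of 4 come from the snippet above; the helper names are made up):

import numpy as np

def make_initial_state(frame, history=4):
    # Repeat the first frame `history` times along the channel axis
    return np.stack([frame] * history, axis=2)

def append_frame(state, frame):
    # Drop the oldest frame and append the newest one
    return np.append(state[:, :, 1:], np.expand_dims(frame, 2), axis=2)

# Usage mirroring the loop above:
# state = make_initial_state(pe.env.reset().reshape(93, 93))
# state = append_frame(state, next_state.reshape(93, 93))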
Example #9
  def setUp(self):
    super(PolicyMonitorTest, self).setUp()

    self.env = make_env()
    self.global_step = tf.Variable(0, name="global_step", trainable=False)
    self.summary_writer = tf.train.SummaryWriter(tempfile.mkdtemp())

    with tf.variable_scope("global") as vs:
      self.global_policy_net = PolicyEstimator(len(VALID_ACTIONS))
      self.global_value_net = ValueEstimator(reuse=True)
Example #10
    def __init__(self, env, policy_net):
        self.env = env
        self.global_policy_net = policy_net

        # Local policy net
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(policy_net.num_outputs)

        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(scope="policy_eval", collection=tf.GraphKeys.TRAINABLE_VARIABLES))
Example #11
    def setUp(self):
        super(WorkerTest, self).setUp()

        self.env = make_env()
        self.discount_factor = 0.99
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.global_counter = itertools.count()
        self.sp = StateProcessor()

        with tf.variable_scope("global") as vs:
            self.global_policy_net = PolicyEstimator(len(VALID_ACTIONS))
            self.global_value_net = ValueEstimator(reuse=True)
Example #12
    def __init__(self,
                 name,
                 env,
                 env_id,
                 curriculum,
                 policy_nets,
                 value_nets,
                 shared_final_layer,
                 global_counter,
                 discount_factor=0.99,
                 summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_nets = policy_nets
        self.global_value_nets = value_nets
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.env = env
        self.env_id = env_id
        self.curriculum = curriculum
        self.shared_final_layer = shared_final_layer

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            self.policy_net = PolicyEstimator(policy_nets[0].num_outputs,
                                              state_dims=env.get_state_size())
            self.value_net = ValueEstimator(reuse=True,
                                            state_dims=env.get_state_size())

        # Op to copy params from global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global_{}".format(env_id),
                collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op = self.make_train_op(self.value_net,
                                                self.global_value_nets)
        self.pnet_train_op = self.make_train_op(self.policy_net,
                                                self.global_policy_nets)
        if self.shared_final_layer:
            # create ops to train the final layers of all other agents
            self.policy_layer_train_ops = self.make_final_layer_train_ops(
                self.policy_net, self.global_policy_nets, 'policy')
            self.value_layer_train_ops = self.make_final_layer_train_ops(
                self.value_net, self.global_value_nets, 'value')

        self.state = None
        self.epochs = 0
Example #13
    def setUp(self):
        super(WorkerTest, self).setUp()

        self.env = make_env()
        self.discount_factor = 0.99
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.global_counter = itertools.count()
        self.sp = StateProcessor()

        with tf.variable_scope("global") as vs:
            print("length of the actions: {}".format(len(VALID_ACTIONS)))
            self.global_policy_net = PolicyEstimator(len(VALID_ACTIONS))
            self.global_value_net = ValueEstimator(reuse=True)

    def testPolicyNetPredict(self):
        w = Worker(name="test",
                   env=make_env(),
                   policy_net=self.global_policy_net,
                   value_net=self.global_value_net,
                   global_counter=self.global_counter,
                   discount_factor=self.discount_factor)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            state = self.sp.process(self.env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            action_values = w._policy_net_predict(processed_state, sess)
            self.assertEqual(action_values.shape, (4, ))

    def testValueNetPredict(self):
        w = Worker(name="test",
                   env=make_env(),
                   policy_net=self.global_policy_net,
                   value_net=self.global_value_net,
                   global_counter=self.global_counter,
                   discount_factor=self.discount_factor)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            state = self.sp.process(self.env.reset())
            processed_state = atari_helpers.atari_make_initial_state(state)
            w.state = processed_state
            transitions, local_t, global_t = w.run_n_steps(10, sess)
            policy_net_loss, value_net_loss, policy_net_summaries, value_net_summaries = w.update(
                transitions, sess)

        self.assertEqual(len(transitions), 10)
        self.assertIsNotNone(policy_net_loss)
        self.assertIsNotNone(value_net_loss)
        self.assertIsNotNone(policy_net_summaries)
        self.assertIsNotNone(value_net_summaries)
Example #14
  def __init__(self, env, policy_net, task):

    self.env = env
    self.global_policy_net = policy_net
    self.task = task

    # Local policy net
    with tf.variable_scope("policy_visualization"):
      self.policy_net = PolicyEstimator(policy_net.num_outputs, state_dims=self.env.get_state_size())

    # Op to copy params from global policy/value net parameters
    self.copy_params_op = make_copy_params_op(
      tf.contrib.slim.get_variables(scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
      tf.contrib.slim.get_variables(scope="policy_visualization", collection=tf.GraphKeys.TRAINABLE_VARIABLES))
Example #15
    def __init__(self,
                 name,
                 env,
                 policy_net,
                 value_net,
                 global_counter,
                 discount_factor=0.99,
                 summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        #self.sp = StateProcessor()
        self.summary_writer = summary_writer

        self.env = CDLL('./PythonAccessToSim.so')
        self.env.step.restype = step_result
        self.env.send_command.restype = c_int
        self.env.initialize.restype = c_int
        self.env.recieve_state_gui.restype = step_result
        self.actions = list(range(0, 3 * Num_Targets))

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            self.policy_net = PolicyEstimator(policy_net.num_outputs)
            self.value_net = ValueEstimator(reuse=True)

        # Op to copy params from global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op = make_train_op(self.value_net,
                                           self.global_value_net)
        self.pnet_train_op = make_train_op(self.policy_net,
                                           self.global_policy_net)

        self.state = None
        self.sim = None
Example #16
    def __init__(self,
                 env,
                 env_id,
                 curriculum,
                 policy_net,
                 saver=None,
                 n_eval=10,
                 logfile=None,
                 checkpoint_path=None):

        self.env = env
        self.env_id = env_id
        self.curriculum = curriculum
        self.global_policy_net = policy_net
        self.saver = saver
        self.n_eval = n_eval
        self.checkpoint_path = checkpoint_path
        self.logger = logging.getLogger('eval runs {}'.format(env_id))
        hdlr = logging.FileHandler(logfile)
        formatter = logging.Formatter(
            '[%(asctime)s] [%(levelname)s] %(message)s')
        hdlr.setFormatter(formatter)
        self.logger.addHandler(hdlr)
        self.logger.setLevel(logging.INFO)

        # Local policy net
        with tf.variable_scope("policy_eval_{}".format(env_id)):
            self.policy_net = PolicyEstimator(
                policy_net.num_outputs, state_dims=self.env.get_state_size())

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global_{}".format(env_id),
                collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval_{}".format(env_id),
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))
        self.epochs = 0
Example #17
    def __init__(self, env, policy_net, summary_writer, saver=None):

        self.env = env
        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.counter = 0

        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))

        # Local policy net
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(policy_net.num_outputs,
                                              policy_net.observation_space)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))
Example #18
def main():

    # Depending on the game we may have a limited action space
    env_ = gym.make(FLAGS.env)
    num_actions = env_.action_space.n
    dim_obs = list(env_.observation_space.shape)
    assert len(dim_obs) == 3 and dim_obs[2] == 3  # make sure it is an RGB frame
    N_FRAME = FLAGS.n_frame if FLAGS.n_frame else 1
    dim_obs[2] *= N_FRAME
    print("Valid number of actions is {}".format(num_actions))
    print("The dimension of the observation space is {}".format(dim_obs))
    env_.close()

    # Set the number of workers
    NUM_WORKERS = (FLAGS.parallelism
                   if FLAGS.parallelism else multiprocessing.cpu_count())

    MODEL_DIR = FLAGS.model_dir
    CP_H = FLAGS.checkpoint_hour
    CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")
    TENSORBOARD_DIR = os.path.join(MODEL_DIR, "tb")

    # Optionally empty model directory
    if FLAGS.reset:
        shutil.rmtree(MODEL_DIR, ignore_errors=True)

    if not os.path.exists(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)

    summary_writer = tf.summary.FileWriter(TENSORBOARD_DIR)

    with tf.device("/cpu:0"):

        # Keeps track of the number of updates we've performed
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Global policy and value nets
        with tf.variable_scope("global") as vs:
            policy_net = PolicyEstimator(num_outputs=num_actions,
                                         dim_inputs=dim_obs)
            value_net = ValueEstimator(reuse=True, dim_inputs=dim_obs)

        # Global step iterator
        global_counter = itertools.count()

        # Create worker graphs
        workers = []
        for worker_id in range(NUM_WORKERS):
            # We only write summaries in one of the workers because they're
            # pretty much identical and writing them on all workers
            # would be a waste of space
            worker_summary_writer = None
            if worker_id == 0:
                worker_summary_writer = summary_writer

            worker = Worker(name="worker_{}".format(worker_id),
                            env=gym.make(FLAGS.env),
                            policy_net=policy_net,
                            value_net=value_net,
                            global_counter=global_counter,
                            discount_factor=0.99,
                            summary_writer=worker_summary_writer,
                            max_global_steps=FLAGS.max_global_steps,
                            n_frame=N_FRAME)
            workers.append(worker)

        saver = tf.train.Saver(keep_checkpoint_every_n_hours=CP_H,
                               max_to_keep=10)

        # Used to occasionally save videos for our policy net
        # and write episode rewards to Tensorboard
        pe = PolicyMonitor(env=gym.make(FLAGS.env),
                           policy_net=policy_net,
                           summary_writer=summary_writer,
                           saver=saver)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()

        # Load a previous checkpoint if it exists
        latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
        if latest_checkpoint:
            print("Loading model checkpoint: {}".format(latest_checkpoint))
            saver.restore(sess, latest_checkpoint)

        # Start worker threads
        worker_threads = []
        for worker in workers:
            print("starting worker:")
            # Bind worker as a default arg to avoid lambda late binding
            worker_fn = lambda worker=worker: worker.run(sess, coord, FLAGS.t_max)
            t = threading.Thread(target=worker_fn)
            t.start()
            worker_threads.append(t)

        # Start a thread for policy eval task
        monitor_thread = threading.Thread(
            target=lambda: pe.continuous_eval(FLAGS.eval_every, sess, coord))
        monitor_thread.start()

        # Wait for all workers to finish
        coord.join(worker_threads)
Example #19
  shutil.rmtree(MODEL_DIR, ignore_errors=True)

if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)


summary_writer = tf.summary.FileWriter(os.path.join(MODEL_DIR, "train"))

with tf.device("/cpu:0"):
    
    # Keeps track of the number of updates we've performed
    global_step = tf.Variable(0, name="global_step", trainable=False)
    
    # Global policy and value nets
    with tf.variable_scope("global") as vs:
        policy_net = PolicyEstimator(num_outputs=len(VALID_ACTIONS),
                                     observation_space=env_.observation_len)
        value_net = ValueEstimator(observation_space=env_.observation_len,
                                   reuse=True)
        
    # Global step iterator
    global_counter = itertools.count()


    # Create worker graphs
    workers = []
    for worker_id in range(NUM_WORKERS):
        # We only write summaries in one of the workers because they're
        # pretty much identical and writing them on all workers
        # would be a waste of space
        worker_summary_writer = None
        if worker_id == 0:
Example #20
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")

if FLAGS.reset:
    shutil.rmtree(MODEL_DIR, ignore_errors=True)

if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

summary_writer = tf.summary.FileWriter(os.path.join(MODEL_DIR, "train"))

with tf.device("/cpu/0"):

    global_step = tf.Variable(0, name="globa_step", trainable=False)

    with tf.variable_scope("global") as vs:
        policy_net = PolicyEstimator(num_outputs=len(VALID_ACTIONS))
        value_net = ValueEstimator(reuse=True)

    global_counter = itertools.count()

    workers = []

    for worker_id in range(NUM_WORKERS):
        worker_summary_writer = None
        if worker_id == 0:
            worker_summary_writer = summary_writer

        worker = Worker(name="worker_{}".format(worker_id),
                        env=make_env(),
                        policy_net=policy_net,
                        value_net=value_net,
Example #21
    def __init__(self,
                 name,
                 envs,
                 policy_net,
                 value_net,
                 global_counter,
                 domain,
                 instances,
                 discount_factor=0.99,
                 summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.domain = domain
        self.instances = instances
        self.dropout = 0.0
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.summary_writer = summary_writer
        self.envs = envs
        self.n = self.envs[0].num_state_vars

        self.N = len(instances)
        self.current_instance = 0

        assert (policy_net.num_inputs == value_net.num_inputs)
        assert (self.N == len(self.envs))
        self.num_inputs = policy_net.num_inputs

        # Construct adjacency lists
        self.adjacency_lists = [None] * self.N
        self.nf_features = [None] * self.N
        self.single_adj_preprocessed_list = [None] * self.N

        for i in range(self.N):
            self.instance_parser = InstanceParser(self.domain,
                                                  self.instances[i])
            self.fluent_feature_dims, self.nonfluent_feature_dims = self.instance_parser.get_feature_dims(
            )
            self.nf_features[i] = self.instance_parser.get_nf_features()
            adjacency_list = self.instance_parser.get_adjacency_list()
            self.adjacency_lists[i] = nx.adjacency_matrix(
                nx.from_dict_of_lists(adjacency_list))
            self.single_adj_preprocessed_list[i] = preprocess_adj(
                self.adjacency_lists[i])

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            self.policy_net = PolicyEstimator(
                policy_net.num_inputs, self.N, policy_net.num_hidden1,
                policy_net.num_hidden2, policy_net.num_outputs,
                policy_net.fluent_feature_dims,
                policy_net.nonfluent_feature_dims, policy_net.activation,
                policy_net.learning_rate)
            self.value_net = ValueEstimator(
                value_net.num_inputs, self.N, value_net.num_hidden1,
                value_net.num_hidden2, value_net.fluent_feature_dims,
                value_net.nonfluent_feature_dims, value_net.activation,
                value_net.learning_rate)

        # Op to copy params from global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op_list = [None] * self.N
        self.pnet_train_op_list = [None] * self.N

        for i in range(self.N):
            self.vnet_train_op_list[i] = make_train_op(
                self.value_net, self.global_value_net, i)
            self.pnet_train_op_list[i] = make_train_op(
                self.policy_net, self.global_policy_net, i)

        self.state = None
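preprocess_adj is not shown on this page. In GCN-style code this name usually refers to the renormalization trick from Kipf & Welling (add self-loops, then symmetrically normalize); a sketch under that assumption, which may differ in details from the helper used here (for example, the real helper may also convert the result to a sparse tuple representation):

import numpy as np
import scipy.sparse as sp

def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix: D^-1/2 A D^-1/2."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

def preprocess_adj(adj):
    """Adds self-loops and normalizes, matching the GCN renormalization trick."""
    return normalize_adj(adj + sp.eye(adj.shape[0]))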
Example #22
    os.path.join(MODEL_DIR, "val_summaries"))

# Keeps track of the number of updates we've performed
global_step = tf.Variable(0, name="global_step", trainable=False)

global_learning_rate = FLAGS.lr

# Global policy and value nets
with tf.variable_scope("global") as vs:
    policy_net = PolicyEstimator(num_nodes_list=num_nodes_list,
                                 fluent_feature_dims=fluent_feature_dims,
                                 nonfluent_feature_dims=nonfluent_feature_dims,
                                 N=FLAGS.num_instances,
                                 num_valid_actions_list=num_valid_actions_list,
                                 action_details_list=action_details_list,
                                 num_graph_fluent_list=num_graph_fluent_list,
                                 num_gcn_hidden=policy_num_gcn_hidden,
                                 num_action_dim=FLAGS.num_action_dim,
                                 num_decoder_dim=FLAGS.num_decoder_dim,
                                 num_adjacency_list=num_adjacency_list,
                                 num_gat_layers=FLAGS.num_gat_layers,
                                 activation=FLAGS.activation,
                                 learning_rate=global_learning_rate)

    value_net = ValueEstimator(num_nodes_list=num_nodes_list,
                               fluent_feature_dims=fluent_feature_dims,
                               nonfluent_feature_dims=nonfluent_feature_dims,
                               N=FLAGS.num_instances,
                               num_graph_fluent_list=num_graph_fluent_list,
                               num_gcn_hidden=value_num_gcn_hidden,
                               num_action_dim=FLAGS.num_action_dim,
                               num_decoder_dim=FLAGS.num_decoder_dim,
Example #23
                actor_lr *= .999
                break

with tf.Session() as sess:

    env = gym.envs.make("CartPole-v1")

    seed = FLAGS.seed
    np.random.seed(seed)
    env.seed(seed)
    tf.set_random_seed(seed)    
    print('seed: {} '.format(seed))

    policy_estimator = PolicyEstimator(env)
    value_estimator = ValueEstimator(env)

    sess.run(tf.global_variables_initializer())
    actor_critic(sess, env, policy_estimator,
                 value_estimator, FLAGS.num_episodes)
Example #24
if not os.path.exists(CHECKPOINT_DIR):
  os.makedirs(CHECKPOINT_DIR)

if not os.path.exists(LOG_DIR):
  os.makedirs(LOG_DIR)

with tf.device("/cpu:0"):
  # Keeps track of the number of updates we've performed
  global_step = tf.Variable(0, name="global_step", trainable=False)

  # different policy and value nets for all tasks
  policy_nets = []
  value_nets = []
  for e in range(len(envs)):
    with tf.variable_scope("global_{}".format(e)) as vs:
      policy_nets.append(PolicyEstimator(
        num_outputs=len(VALID_ACTIONS), state_dims=envs[e].get_state_size()))
      value_nets.append(ValueEstimator(
        reuse=True, state_dims=envs[e].get_state_size()))
  if FLAGS.shared_final_layer:
    # make all final layer weights the same
    initial_copy_ops = []
    for e in range(1, len(envs)):
      initial_copy_ops += make_copy_params_op(
        tf.contrib.slim.get_variables(scope="global_0/policy_net", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
        tf.contrib.slim.get_variables(scope="global_{}/policy_net".format(e), collection=tf.GraphKeys.TRAINABLE_VARIABLES))
      initial_copy_ops += make_copy_params_op(
        tf.contrib.slim.get_variables(scope="global_0/value_net", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
        tf.contrib.slim.get_variables(scope="global_{}/value_net".format(e), collection=tf.GraphKeys.TRAINABLE_VARIABLES))

  # Global step iterator
  global_counter = itertools.count()
Example #25
    def __init__(self,
                 envs,
                 policy_net,
                 domain,
                 instances,
                 neighbourhood,
                 summary_writer,
                 saver=None):

        self.stats_dir = os.path.join(summary_writer.get_logdir(), "../stats")
        self.stats_dir = os.path.abspath(self.stats_dir)

        self.domain = domain
        self.instances = instances
        self.N = len(instances)
        self.num_nodes_list = policy_net.num_nodes_list
        self.num_adjacency_list = policy_net.num_adjacency_list

        self.envs = envs
        self.global_policy_net = policy_net

        # Construct adjacency list
        self.adjacency_lists = [None] * self.N
        self.nf_features = [None] * self.N
        self.adjacency_lists_with_biases = [None] * self.N

        for i in range(self.N):
            self.fluent_feature_dims, self.nonfluent_feature_dims = self.envs[
                i].get_feature_dims()
            self.nf_features[i] = self.envs[i].get_nf_features()

            adjacency_list = self.envs[i].get_adjacency_list()
            self.adjacency_lists[i] = [
                get_adj_mat_from_list(aj) for aj in adjacency_list
            ]
            self.adjacency_lists_with_biases[i] = [
                process.adj_to_bias(np.array([aj]), [self.num_nodes_list[i]],
                                    nhood=neighbourhood)[0]
                for aj in self.adjacency_lists[i]
            ]

        self.summary_writer = summary_writer
        self.saver = saver

        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))

        try:
            os.makedirs(self.stats_dir)
        except OSError:
            # The stats directory may already exist
            pass

        # Local policy net
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(
                policy_net.num_nodes_list, policy_net.fluent_feature_dims,
                policy_net.nonfluent_feature_dims, policy_net.N,
                policy_net.num_valid_actions_list,
                policy_net.action_details_list,
                policy_net.num_graph_fluent_list, policy_net.num_gcn_hidden,
                policy_net.num_action_dim, policy_net.num_decoder_dim,
                policy_net.num_adjacency_list, policy_net.num_gat_layers,
                policy_net.activation, policy_net.learning_rate)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))
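process.adj_to_bias converts adjacency matrices into attention bias masks for the graph-attention layers. If it follows the reference GAT implementation, nodes reachable within nhood hops get a bias of 0 and everything else a large negative value, so softmax attention effectively ignores non-neighbours; a sketch under that assumption:

import numpy as np

def adj_to_bias(adj, sizes, nhood=1):
    """adj: (num_graphs, N, N) adjacency matrices; sizes: nodes per graph.
    Returns 0 where a node is within `nhood` hops, -1e9 elsewhere."""
    nb_graphs = adj.shape[0]
    mt = np.empty(adj.shape)
    for g in range(nb_graphs):
        mt[g] = np.eye(adj.shape[1])
        # Powers of (A + I) mark the nhood-step neighbourhood
        for _ in range(nhood):
            mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1])))
        for i in range(sizes[g]):
            for j in range(sizes[g]):
                if mt[g][i][j] > 0.0:
                    mt[g][i][j] = 1.0
    return -1e9 * (1.0 - mt)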
Example #26
class PolicyMonitor(object):
    """
    Helps evaluate a policy by running an episode in an environment,
    saving a video, and plotting summaries to Tensorboard.

    Args:
        env: environment to run in
        policy_net: A policy estimator
        summary_writer: a tf.train.SummaryWriter used to write Tensorboard summaries
    """
    def __init__(self, env, policy_net, summary_writer, saver=None):

        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.env = env

        # Correct the path
        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))
        print('[PM] checkpoint_path: {}'.format(self.checkpoint_path))

        # Local policy net
        with tf.variable_scope("policy_eval"):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))

    def eval_once(self, sess):
        with sess.as_default(), sess.graph.as_default():
            # Copy params to local model
            global_step, _ = sess.run(
                [tf.contrib.framework.get_global_step(), self.copy_params_op])

            # Run an episode
            done = False
            state = self.env.reset()
            if LSTM_POLICY:
                self.policy_net.reset_lstm_features()
            total_reward = 0.0
            episode_length = 0
            while not done:
                if LSTM_POLICY:
                    action_probs = self.policy_net.action_inference(state)
                else:
                    action_probs = self.policy_net.action_prediction(state)
                action = np.random.choice(np.arange(len(action_probs)),
                                          p=action_probs)
                next_state, reward, done, _ = self.env.step(action)
                # next_state = atari_helpers.atari_make_next_state(state, self.sp.process(next_state))
                total_reward += reward
                episode_length += 1
                state = next_state

            # Add summaries
            episode_summary = tf.Summary()
            episode_summary.value.add(simple_value=total_reward,
                                      tag="eval/total_reward")
            episode_summary.value.add(simple_value=episode_length,
                                      tag="eval/episode_length")
            self.summary_writer.add_summary(episode_summary, global_step)
            self.summary_writer.flush()

            if self.saver is not None:
                self.saver.save(sess, self.checkpoint_path)

            tf.logging.info(
                "Eval results at step {}: total_reward {}, episode_length {}".
                format(global_step, total_reward, episode_length))

            return total_reward, episode_length

    def continuous_eval(self, eval_every, sess, coord):
        """
        Continuously evaluates the policy every [eval_every] seconds.
        """
        try:
            while not coord.should_stop():
                self.eval_once(sess)
                # Sleep until next evaluation cycle
                time.sleep(eval_every)
        except tf.errors.CancelledError:
            return
Example #27
    def __init__(self,
                 name,
                 envs,
                 policy_net,
                 value_net,
                 global_counter,
                 domain,
                 instances,
                 N_train,
                 neighbourhood,
                 discount_factor=0.99,
                 summary_writer=None,
                 max_global_steps=None,
                 train_policy=True):
        self.name = name
        self.domain = domain
        self.instances = instances
        self.dropout = 0.0
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.train.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.summary_writer = summary_writer
        self.envs = envs
        self.num_adjacency_list = policy_net.num_adjacency_list

        self.N = len(instances)
        self.N_train = N_train
        self.current_instance = 0

        assert (self.N == len(self.envs))
        self.num_nodes_list = policy_net.num_nodes_list
        self.train_policy = train_policy

        # Construct adjacency lists
        self.adjacency_lists = [None] * self.N
        self.nf_features = [None] * self.N
        self.adjacency_lists_with_biases = [None] * self.N

        for i in range(self.N):
            self.fluent_feature_dims, self.nonfluent_feature_dims = self.envs[
                i].get_feature_dims()
            self.nf_features[i] = self.envs[i].get_nf_features()
            adjacency_list = self.envs[i].get_adjacency_list()

            self.adjacency_lists[i] = [
                get_adj_mat_from_list(aj) for aj in adjacency_list
            ]
            self.adjacency_lists_with_biases[i] = [
                process.adj_to_bias(
                    np.array([aj]), [self.num_nodes_list[i]],
                    nhood=neighbourhood)[0] for aj in self.adjacency_lists[i]
            ]

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            self.policy_net = PolicyEstimator(
                policy_net.num_nodes_list, policy_net.fluent_feature_dims,
                policy_net.nonfluent_feature_dims, policy_net.N,
                policy_net.num_valid_actions_list,
                policy_net.action_details_list,
                policy_net.num_graph_fluent_list, policy_net.num_gcn_hidden,
                policy_net.num_action_dim, policy_net.num_decoder_dim,
                policy_net.num_adjacency_list, policy_net.num_gat_layers,
                policy_net.activation, policy_net.learning_rate)

            self.value_net = ValueEstimator(
                value_net.num_nodes_list, value_net.fluent_feature_dims,
                value_net.nonfluent_feature_dims, self.N,
                value_net.num_graph_fluent_list, value_net.num_gcn_hidden,
                value_net.num_action_dim, value_net.num_decoder_dim,
                value_net.num_adjacency_list, value_net.num_gat_layers,
                value_net.activation, value_net.learning_rate)

        # Op to copy params from global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op_list = [None] * self.N

        if self.train_policy:
            self.pnet_train_op_list = [None] * self.N

        for i in range(self.N):
            self.vnet_train_op_list[i] = make_train_op(
                self.value_net, self.global_value_net, i)
            if self.train_policy:
                self.pnet_train_op_list[i] = make_train_op(
                    self.policy_net, self.global_policy_net, i)

        self.state = None

        self.start_time = time.time()
Example #28
    def __init__(self,
                 envs,
                 policy_net,
                 domain,
                 instances,
                 summary_writer,
                 saver=None):

        self.stats_dir = os.path.join(summary_writer.get_logdir(), "../stats")
        self.stats_dir = os.path.abspath(self.stats_dir)

        self.n = envs[0].num_state_vars
        self.domain = domain
        self.instances = instances
        self.N = len(instances)

        self.envs = envs
        self.global_policy_net = policy_net

        # Construct adjacency list
        self.adjacency_lists = [None] * self.N
        self.single_adj_preprocessed_list = [None] * self.N

        for i in range(self.N):
            self.instance_parser = InstanceParser(self.domain,
                                                  self.instances[i])
            self.fluent_feature_dims, self.nonfluent_feature_dims = self.instance_parser.get_feature_dims(
            )
            self.nf_features = self.instance_parser.get_nf_features()
            adjacency_list = self.instance_parser.get_adjacency_list()
            self.adjacency_lists[i] = nx.adjacency_matrix(
                nx.from_dict_of_lists(adjacency_list))
            self.single_adj_preprocessed_list[i] = preprocess_adj(
                self.adjacency_lists[i])

        self.summary_writer = summary_writer
        self.saver = saver

        self.checkpoint_path = os.path.abspath(
            os.path.join(summary_writer.get_logdir(), "../checkpoints/model"))

        try:
            os.makedirs(self.stats_dir)
        except OSError:
            # The stats directory may already exist
            pass

        # Local policy net
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(
                policy_net.num_inputs, policy_net.N, policy_net.num_hidden1,
                policy_net.num_hidden2, policy_net.num_hidden_transition,
                policy_net.num_outputs, policy_net.fluent_feature_dims,
                policy_net.nonfluent_feature_dims, policy_net.activation,
                policy_net.learning_rate)

        # Op to copy params from global policy/value net parameters
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope="policy_eval",
                collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.num_inputs = policy_net.num_inputs
Example #29
if FLAGS.lr_decay:
    global_learning_rate = tf.train.exponential_decay(FLAGS.lr,
                                                      global_step,
                                                      500000,
                                                      0.3,
                                                      staircase=True)
else:
    global_learning_rate = FLAGS.lr

# Global policy and value nets
with tf.variable_scope("global") as vs:
    policy_net = PolicyEstimator(num_inputs=num_inputs,
                                 fluent_feature_dims=fluent_feature_dims,
                                 nonfluent_feature_dims=nonfluent_feature_dims,
                                 N=1,
                                 num_hidden1=policy_num_hidden1,
                                 num_hidden2=policy_num_hidden2,
                                 num_outputs=num_valid_actions,
                                 activation=FLAGS.activation,
                                 learning_rate=global_learning_rate)
    value_net = ValueEstimator(num_inputs=num_inputs,
                               fluent_feature_dims=fluent_feature_dims,
                               nonfluent_feature_dims=nonfluent_feature_dims,
                               N=1,
                               num_hidden1=value_num_hidden1,
                               num_hidden2=value_num_hidden2,
                               activation=FLAGS.activation,
                               learning_rate=global_learning_rate)

current_sa_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope='global/policy_net/gconv1_vars')
Example #30
class Worker(object):
    """
    An A3C worker thread. Runs episodes locally and updates global shared value and policy nets.

    Args:
        name: A unique name for this worker
        env: The Gym environment used by this worker
        policy_net: Instance of the globally shared policy net
        value_net: Instance of the globally shared value net
        global_counter: Iterator that holds the global step
        discount_factor: Reward discount factor
        summary_writer: A tf.train.SummaryWriter for Tensorboard summaries
        max_global_steps: If set, stop coordinator when global_counter > max_global_steps
    """
    def __init__(self,
                 name,
                 env,
                 policy_net,
                 value_net,
                 global_counter,
                 discount_factor=0.99,
                 summary_writer=None,
                 max_global_steps=None):
        self.name = name
        self.discount_factor = discount_factor
        self.max_global_steps = max_global_steps
        self.global_step = tf.contrib.framework.get_global_step()
        self.global_policy_net = policy_net
        self.global_value_net = value_net
        self.global_counter = global_counter
        self.local_counter = itertools.count()
        self.summary_writer = summary_writer
        self.env = env

        # Create local policy/value nets that are not updated asynchronously
        with tf.variable_scope(name):
            if LSTM_POLICY:
                self.policy_net = LSTMPolicyEstimator(policy_net.num_outputs)
            else:
                self.policy_net = PolicyEstimator(policy_net.num_outputs)
            self.value_net = ValueEstimator(reuse=True)

        # Op to copy params from global policy/value nets
        self.copy_params_op = make_copy_params_op(
            tf.contrib.slim.get_variables(
                scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES),
            tf.contrib.slim.get_variables(
                scope=self.name, collection=tf.GraphKeys.TRAINABLE_VARIABLES))

        self.vnet_train_op = make_train_op(self.value_net,
                                           self.global_value_net)
        self.pnet_train_op = make_train_op(self.policy_net,
                                           self.global_policy_net)

        self.state = None

    def run(self, sess, coord, t_max):
        with sess.as_default(), sess.graph.as_default():
            # Initial state
            self.state = self.env.reset()
            try:
                while not coord.should_stop():
                    # Copy Parameters from the global networks
                    sess.run(self.copy_params_op)

                    # Collect some experience
                    transitions, local_t, global_t = self.run_n_steps(
                        t_max, sess)

                    if self.max_global_steps is not None and global_t >= self.max_global_steps:
                        tf.logging.info(
                            "Reached global step {}. Stopping.".format(
                                global_t))
                        coord.request_stop()
                        return

                    # Update the global networks
                    self.update(transitions, sess)

            except tf.errors.CancelledError:
                return

    def run_n_steps(self, n, sess):
        transitions = []
        if LSTM_POLICY:
            self.policy_net.reset_lstm_features()
        for _ in range(n):
            # Take a step
            if LSTM_POLICY:
                action_probs = self.policy_net.action_inference(self.state)
            else:
                action_probs = self.policy_net.action_prediction(self.state)

            # Sample an action from the policy distribution
            action = np.random.choice(np.arange(len(action_probs)),
                                      p=action_probs)
            next_state, reward, done, _ = self.env.step(action)

            # Store transition
            transitions.append(
                Transition(state=self.state,
                           action=action,
                           reward=reward,
                           next_state=next_state,
                           done=done))

            # Increase local and global counters
            local_t = next(self.local_counter)
            global_t = next(self.global_counter)

            if local_t % 1000 == 0:
                tf.logging.info("{}: local Step {}, global step {}".format(
                    self.name, local_t, global_t))

            if done:
                self.state = self.env.reset()
                ### reset features
                if LSTM_POLICY:
                    self.policy_net.reset_lstm_features()
                break
            else:
                self.state = next_state
        return transitions, local_t, global_t

    def update(self, transitions, sess):
        """
        Updates global policy and value networks based on collected experience

        Args:
            transitions: A list of experience transitions
            sess: A Tensorflow session
        """

        # If the episode was not done, bootstrap the value from the last state
        reward = 0.0
        if not transitions[-1].done:
            reward = self.value_net.predict_value(transitions[-1].next_state)

        if LSTM_POLICY:
            init_lstm_state = self.policy_net.get_init_features()

        # Accumulate minibatch examples
        states = []
        policy_targets = []
        value_targets = []
        actions = []
        features = []

        for transition in transitions[::-1]:
            reward = transition.reward + self.discount_factor * reward
            policy_target = (reward -
                             self.value_net.predict_value(transition.state))
            # Accumulate updates
            states.append(transition.state)
            actions.append(transition.action)
            policy_targets.append(policy_target)
            value_targets.append(reward)

        if LSTM_POLICY:
            feed_dict = {
                self.policy_net.states: np.array(states),
                self.policy_net.targets: policy_targets,
                self.policy_net.actions: actions,
                self.policy_net.state_in[0]: np.array(init_lstm_state[0]),
                self.policy_net.state_in[1]: np.array(init_lstm_state[1]),
                self.value_net.states: np.array(states),
                self.value_net.targets: value_targets,
            }
        else:
            feed_dict = {
                self.policy_net.states: np.array(states),
                self.policy_net.targets: policy_targets,
                self.policy_net.actions: actions,
                self.value_net.states: np.array(states),
                self.value_net.targets: value_targets,
            }

        # Train the global estimators using local gradients
        global_step, pnet_loss, vnet_loss, _, _, pnet_summaries, vnet_summaries = sess.run(
            [
                self.global_step, self.policy_net.loss, self.value_net.loss,
                self.pnet_train_op, self.vnet_train_op,
                self.policy_net.summaries, self.value_net.summaries
            ], feed_dict)

        # Write summaries
        if self.summary_writer is not None and global_step % SUMMARY_EACH_STEPS == 0:
            self.summary_writer.add_summary(pnet_summaries, global_step)
            self.summary_writer.add_summary(vnet_summaries, global_step)
            self.summary_writer.flush()

        return pnet_loss, vnet_loss, pnet_summaries, vnet_summaries
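The reversed loop in update() is computing bootstrapped n-step returns, R_t = r_t + gamma * R_{t+1}, seeded with the value estimate of the final next_state when the rollout did not end the episode; the policy target is then the advantage R_t - V(s_t). The return computation alone, as a standalone sketch (field names match the Transition tuple used above; the function name is made up):

def n_step_returns(transitions, bootstrap_value, discount_factor):
    """Discounted returns for a rollout, ordered oldest transition first."""
    reward = 0.0 if transitions[-1].done else bootstrap_value
    returns = []
    for transition in reversed(transitions):
        reward = transition.reward + discount_factor * reward
        returns.append(reward)
    return list(reversed(returns))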