import sys
import time

import numpy as np
import tensorflow as tf

# DaggerLSTM, TrueDaggerExpert, Status, normalize and one_hot are
# project-local dependencies defined elsewhere in this package; import
# them from wherever the surrounding repo provides them.


class DaggerWorker(object):
    def __init__(self, cluster, server, task_idx, env):
        # Distributed TensorFlow and logging related
        self.cluster = cluster
        self.env = env
        self.task_idx = task_idx
        self.leader_device = '/job:ps/task:0'
        self.worker_device = '/job:worker/task:%d' % task_idx
        self.num_workers = cluster.num_tasks('worker')

        # Buffers and parameters required to train
        self.curr_ep = 0
        self.state_buf = []
        self.action_buf = []
        self.state_dim = env.state_dim
        self.action_cnt = env.action_cnt
        self.aug_state_dim = self.state_dim + self.action_cnt
        self.prev_action = self.action_cnt - 1
        self.expert = TrueDaggerExpert(env)

        # Must call env.set_sample_action() before env.rollout()
        env.set_sample_action(self.sample_action)

        # Set up TensorFlow for synchronization and training
        self.setup_tf_ops()
        self.sess = tf.Session(
            server.target, config=tf.ConfigProto(allow_soft_placement=True))
        self.sess.run(tf.global_variables_initializer())

    def cleanup(self):
        self.env.cleanup()
        self.sess.run(self.sync_q.enqueue(Status.WORKER_DONE))

    def setup_tf_ops(self):
        """Sets up the shared TensorFlow operators and structures.
        Refer to DaggerLeader for more information.
        """
        # Set up the shared global network and the local network.
        with tf.device(self.leader_device):
            with tf.variable_scope('global_cpu'):
                self.global_network_cpu = DaggerLSTM(
                    state_dim=self.aug_state_dim, action_cnt=self.action_cnt)

        with tf.device(self.worker_device):
            with tf.variable_scope('local'):
                self.local_network = DaggerLSTM(
                    state_dim=self.aug_state_dim, action_cnt=self.action_cnt)

        self.init_state = self.local_network.zero_init_state(1)
        self.lstm_state = self.init_state

        # Build shared queues for training data and synchronization
        self.train_q = tf.FIFOQueue(
            self.num_workers, [tf.float32, tf.int32],
            shared_name='training_feed')
        self.sync_q = tf.FIFOQueue(
            3, [tf.int16], shared_name=('sync_q_%d' % self.task_idx))

        # Training data is [[aug_state]], [action]
        self.state_data = tf.placeholder(
            tf.float32, shape=(None, self.aug_state_dim))
        self.action_data = tf.placeholder(tf.int32, shape=(None,))
        self.enqueue_train_op = self.train_q.enqueue(
            [self.state_data, self.action_data])

        # Sync the local network to the global network (CPU)
        local_vars = self.local_network.trainable_vars
        global_vars = self.global_network_cpu.trainable_vars
        self.sync_op = tf.group(
            *[v1.assign(v2) for v1, v2 in zip(local_vars, global_vars)])

    def sample_action(self, state):
        """Given the state from the last step, returns an action to take.

        Also appends to the state/action buffers the augmented state and
        the "correct" action to take according to the expert.
        """
        cwnd = state[self.state_dim - 1]
        expert_action = self.expert.sample_action(cwnd)

        # Normalize the state for decision-making and augment it with a
        # one-hot encoding of the previous action.
        norm_state = normalize(state)
        one_hot_action = one_hot(self.prev_action, self.action_cnt)
        aug_state = norm_state + one_hot_action

        # Fill in state_buf, action_buf
        self.state_buf.append(aug_state)
        self.action_buf.append(expert_action)

        # Always use the expert on the first episode to get our bearings.
        if self.curr_ep == 0:
            self.prev_action = expert_action
            return expert_action

        # Get the probability of each action from the local network.
        pi = self.local_network
        feed_dict = {
            pi.input: [[aug_state]],
            pi.state_in: self.lstm_state,
        }
        ops_to_run = [pi.action_probs, pi.state_out]
        action_probs, self.lstm_state = self.sess.run(ops_to_run, feed_dict)

        # Choose an action to take and update the current LSTM state
        # action = np.argmax(np.random.multinomial(1, action_probs[0][0] - 1e-5))
        action = np.argmax(action_probs[0][0])
        self.prev_action = action
        return action

    def rollout(self):
        """Starts an episode/flow with an empty dataset/environment."""
        self.state_buf = []
        self.action_buf = []
        self.prev_action = self.action_cnt - 1
        self.lstm_state = self.init_state
        self.env.reset()
        self.env.rollout()

    def run(self, debug=False):
        """Runs for max_ep episodes, each time sending data to the leader."""
        pi = self.local_network

        while True:
            if debug:
                sys.stderr.write('[WORKER %d Ep %d] Starting...\n' %
                                 (self.task_idx, self.curr_ep))

            # Reset local parameters to the global ones
            self.sess.run(self.sync_op)
            print 'DaggerWorker:global_network_cpu:cnt', self.sess.run(
                self.global_network_cpu.cnt)
            print 'DaggerWorker:local_network:cnt', self.sess.run(
                self.local_network.cnt)
            sys.stdout.flush()

            # Start a single episode, populating the state-action buffers.
            self.rollout()

            if debug:
                queue_size = self.sess.run(self.train_q.size())
                sys.stderr.write(
                    '[WORKER %d Ep %d]: enqueueing a sequence of data '
                    'into a queue of size %d\n' %
                    (self.task_idx, self.curr_ep, queue_size))

            # Enqueue a sequence of data into the training queue.
            self.sess.run(self.enqueue_train_op, feed_dict={
                self.state_data: self.state_buf,
                self.action_data: self.action_buf})
            self.sess.run(self.sync_q.enqueue(Status.EP_DONE))

            if debug:
                queue_size = self.sess.run(self.train_q.size())
                sys.stderr.write(
                    '[WORKER %d Ep %d]: finished queueing data; '
                    'queue size now %d\n' %
                    (self.task_idx, self.curr_ep, queue_size))

            if debug:
                sys.stderr.write('[WORKER %d Ep %d]: waiting for server\n' %
                                 (self.task_idx, self.curr_ep))

            # Let the leader dequeue EP_DONE
            time.sleep(0.5)

            # Wait until the pserver finishes training by blocking on sync_q;
            # only proceed once a message from the pserver arrives.
            msg = self.sess.run(self.sync_q.dequeue())
            while msg != Status.WORKER_START and msg != Status.PS_DONE:
                self.sess.run(self.sync_q.enqueue(msg))
                time.sleep(0.5)
                msg = self.sess.run(self.sync_q.dequeue())

            if msg == Status.PS_DONE:
                break

            self.curr_ep += 1
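
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): how a DaggerWorker is
# typically wired into a distributed TF 1.x job. The host:port addresses and
# the environment constructor below are hypothetical placeholders; the real
# entry point and environment class live elsewhere in this repo. The
# environment must expose state_dim, action_cnt, set_sample_action(),
# reset(), rollout() and cleanup(), as used by the class above.
# ---------------------------------------------------------------------------
# cluster = tf.train.ClusterSpec({
#     'ps': ['10.0.0.1:2222'],                       # the DaggerLeader
#     'worker': ['10.0.0.2:2222', '10.0.0.3:2222'],  # one address per worker
# })
# task_idx = 0  # index of this worker within the 'worker' job
# server = tf.train.Server(cluster, job_name='worker', task_index=task_idx)
#
# env = Env(...)  # hypothetical constructor for the project's environment
#
# worker = DaggerWorker(cluster, server, task_idx, env)
# try:
#     worker.run(debug=True)
# finally:
#     worker.cleanup()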