import numpy as np


def update(sess, agent, target_agent, transitions, init_state=None,
           discount_factor=0.99, reward_norm=1.0,
           batch_size=32, time_major=False):
    """One Double DQN update pass over the sampled transitions."""
    loss = 0.0
    time_len = transitions.state.shape[0]

    transitions_it = zip(
        iterate_minibatches(transitions.state, batch_size),
        iterate_minibatches(transitions.action, batch_size),
        iterate_minibatches(transitions.reward, batch_size),
        iterate_minibatches(transitions.next_state, batch_size),
        iterate_minibatches(transitions.done, batch_size))

    for states, actions, rewards, next_states, dones in transitions_it:
        # Double DQN: the online network selects the greedy action,
        # the target network evaluates it.
        qvalues_next = agent.predict_qvalues(sess, next_states)
        best_actions = qvalues_next.argmax(axis=1)
        qvalues_next_target = target_agent.predict_qvalues(sess, next_states)
        # index by the actual minibatch length so a smaller final batch is handled
        qvalues_next_target = qvalues_next_target[
            np.arange(len(best_actions)), best_actions]

        # terminal transitions contribute only the (normalized) immediate reward
        td_target = rewards * reward_norm + \
            np.invert(dones).astype(np.float32) * \
            discount_factor * qvalues_next_target

        run_params = [
            agent.qvalue_net.loss,
            agent.qvalue_net.train_op,
            agent.hidden_state.train_op,
            agent.feature_net.train_op
        ]
        feed_params = {
            agent.feature_net.states: states,
            agent.feature_net.is_training: True,
            agent.qvalue_net.actions: actions,
            agent.qvalue_net.td_target: td_target,
            agent.qvalue_net.is_training: True,
        }

        if agent.special.get("dueling_network", False):
            run_params[0] = agent.agent_loss
            run_params += [agent.value_net.train_op]
            feed_params[agent.value_net.td_target] = td_target  # @TODO: why need to feed?
            feed_params[agent.value_net.is_training] = True

        if isinstance(agent, DrqnAgent):
            # recurrent agents also propagate their belief (hidden) state
            run_params += [agent.hidden_state.belief_update]
            feed_params[agent.hidden_state.is_end] = dones

        run_results = sess.run(run_params, feed_dict=feed_params)
        batch_loss = run_results[0]
        loss += batch_loss

    return loss / time_len
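# Hedged sketch: a minimal, numpy-only illustration of the Double DQN target that the
# loop above computes from agent/target_agent predictions. The helper name
# `double_dqn_td_target` and the array values are made up purely for illustration.
import numpy as np


def double_dqn_td_target(qvalues_next, qvalues_next_target, rewards, dones,
                         discount_factor=0.99, reward_norm=1.0):
    # the online network picks the greedy action ...
    best_actions = qvalues_next.argmax(axis=1)
    # ... and the target network evaluates it
    bootstrap = qvalues_next_target[np.arange(len(best_actions)), best_actions]
    # terminal transitions keep only the (normalized) immediate reward
    return rewards * reward_norm + \
        np.invert(dones).astype(np.float32) * discount_factor * bootstrap

# Example:
#   qvalues_next        = [[1.0, 2.0], [0.5, 0.1]]   # online net, [batch, n_actions]
#   qvalues_next_target = [[0.9, 1.5], [0.4, 0.2]]   # target net, same shape
#   rewards = [1.0, 0.0], dones = [False, True]
#   -> td_target == [1.0 + 0.99 * 1.5, 0.0] == [2.485, 0.0]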
def seq2seq_iter(data, batch_size, double=False):
    """Yields time-major (sequence, target) batches; with ``double=True`` each
    pair is also yielded in the reversed (target -> sequence) direction."""
    indices = np.arange(len(data))
    for batch in iterate_minibatches(indices, batch_size):
        batch = [data[i] for i in batch]
        seq, target = zip(*batch)
        seq, seq_len = time_major_batch(seq)
        target, target_len = time_major_batch(target)
        yield seq, seq_len, target, target_len
        if double:
            yield target, target_len, seq, seq_len
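# Hedged sketch of the data layout seq2seq_iter appears to expect: `data` as a list of
# (sequence, target) token-id pairs, with time_major_batch padding each half of the
# batch into a [time, batch] array and returning the original lengths. The helper
# below is an illustrative stand-in written for this sketch, not the project's
# time_major_batch implementation.
import numpy as np


def pad_time_major(sequences, pad_value=0):
    # pad variable-length sequences to the longest one and stack them
    # so that axis 0 is time and axis 1 is the batch
    lengths = np.array([len(seq) for seq in sequences])
    batch = np.full((lengths.max(), len(sequences)), pad_value, dtype=np.int64)
    for i, seq in enumerate(sequences):
        batch[:len(seq), i] = seq
    return batch, lengths

# Example: pad_time_major([[1, 2, 3], [6]]) returns the (3, 2) array
#   [[1, 6], [2, 0], [3, 0]] together with the lengths [3, 1].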
def update(sess, a3c_agent, transitions, initial_state=None,
           discount_factor=0.99, reward_norm=1.0,
           batch_size=32, time_major=True):
    """One A3C update: builds bootstrapped n-step value targets and advantages,
    then trains the policy, value and feature networks step by step."""
    policy_targets = []
    value_targets = []
    state_history = []
    action_history = []
    done_history = []

    # bootstrap the return from the critic's value of the final next_state
    cumulative_reward = np.zeros_like(transitions[-1].reward) + \
        np.invert(transitions[-1].done) * \
        a3c_agent.predict_values(sess, transitions[-1].next_state)
    for transition in reversed(transitions):
        cumulative_reward = reward_norm * transition.reward + \
            np.invert(transition.done) * discount_factor * cumulative_reward
        # advantage: discounted return minus the critic baseline
        policy_target = cumulative_reward - a3c_agent.predict_values(
            sess, transition.state)
        value_targets.append(cumulative_reward)
        policy_targets.append(policy_target)
        state_history.append(transition.state)
        action_history.append(transition.action)
        done_history.append(transition.done)

    value_targets = np.array(value_targets[::-1])  # time-major
    policy_targets = np.array(policy_targets[::-1])
    state_history = np.array(state_history[::-1])
    action_history = np.array(action_history[::-1])
    done_history = np.array(done_history[::-1])

    if isinstance(a3c_agent, A3CLstmAgent):
        a3c_agent.assign_belief_state(sess, initial_state)

    time_len = state_history.shape[0]
    value_loss, policy_loss = 0.0, 0.0
    for state_axis, action_axis, value_target_axis, policy_target_axis, done_axis in \
            zip(state_history, action_history, value_targets, policy_targets, done_history):
        axis_len = state_axis.shape[0]
        axis_value_loss, axis_policy_loss = 0.0, 0.0

        state_axis = iterate_minibatches(state_axis, batch_size)
        action_axis = iterate_minibatches(action_axis, batch_size)
        value_target_axis = iterate_minibatches(value_target_axis, batch_size)
        policy_target_axis = iterate_minibatches(policy_target_axis, batch_size)
        done_axis = iterate_minibatches(done_axis, batch_size)

        batch_generator = merge_generators([
            state_axis, action_axis, value_target_axis, policy_target_axis, done_axis
        ])

        for state_batch, action_batch, value_target, policy_target, done_batch in batch_generator:
            run_params = [
                a3c_agent.policy_net.loss,
                a3c_agent.value_net.loss,
                a3c_agent.policy_net.train_op,
                a3c_agent.value_net.train_op,
                a3c_agent.feature_net.train_op
            ]
            feed_params = {
                a3c_agent.feature_net.states: state_batch,
                a3c_agent.feature_net.is_training: True,
                a3c_agent.policy_net.actions: action_batch,
                a3c_agent.policy_net.cumulative_rewards: policy_target,
                a3c_agent.policy_net.is_training: True,
                a3c_agent.value_net.td_target: value_target,
                a3c_agent.value_net.is_training: True
            }

            if isinstance(a3c_agent, A3CLstmAgent):
                run_params += [a3c_agent.hidden_state.belief_update]
                feed_params[a3c_agent.hidden_state.is_end] = done_batch

            run_result = sess.run(run_params, feed_dict=feed_params)
            batch_loss_policy = run_result[0]
            batch_loss_value = run_result[1]

            axis_value_loss += batch_loss_value
            axis_policy_loss += batch_loss_policy

        policy_loss += axis_policy_loss / axis_len
        value_loss += axis_value_loss / axis_len

    return policy_loss / time_len, value_loss / time_len
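# Hedged sketch: a numpy-only version of the return/advantage bookkeeping done above.
# The cumulative reward is bootstrapped from the critic's value of the final
# next_state, the rollout is walked backwards with the bootstrap zeroed at terminal
# steps, and the advantage (policy target) is the return minus the critic baseline.
# The helper name n_step_targets and the scalar per-step shapes are illustrative only.
import numpy as np


def n_step_targets(rewards, dones, values, bootstrap_value,
                   discount_factor=0.99, reward_norm=1.0):
    cumulative_reward = (1.0 - dones[-1]) * bootstrap_value
    value_targets, policy_targets = [], []
    for reward, done, value in zip(rewards[::-1], dones[::-1], values[::-1]):
        cumulative_reward = reward_norm * reward + \
            (1.0 - done) * discount_factor * cumulative_reward
        value_targets.append(cumulative_reward)           # critic regression target
        policy_targets.append(cumulative_reward - value)  # advantage for the actor
    return np.array(value_targets[::-1]), np.array(policy_targets[::-1])

# Example: rewards [0, 1], no terminal steps, critic values [0.5, 0.7], bootstrap 0.2
#   -> value targets  [0.99 * 1.198, 1 + 0.99 * 0.2] == [1.186, 1.198]
#   -> policy targets [1.186 - 0.5,  1.198 - 0.7]    == [0.686, 0.498]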
def update(sess, reinforce_agent, transitions, initial_state=None,
           discount_factor=0.99, reward_norm=1.0,
           batch_size=32, time_major=True):
    """One REINFORCE update: discounted Monte Carlo returns serve as the
    policy targets (no learned baseline)."""
    policy_targets = []
    state_history = []
    action_history = []

    cumulative_reward = np.zeros_like(transitions[-1].reward)
    for transition in reversed(transitions):
        cumulative_reward = reward_norm * transition.reward + \
            np.invert(transition.done) * discount_factor * cumulative_reward
        policy_targets.append(cumulative_reward)
        state_history.append(transition.state)
        action_history.append(transition.action)

    # time-major
    policy_targets = np.array(policy_targets[::-1])
    state_history = np.array(state_history[::-1])
    action_history = np.array(action_history[::-1])

    if not time_major:
        state_history = state_history.swapaxes(0, 1)
        action_history = action_history.swapaxes(0, 1)
        policy_targets = policy_targets.swapaxes(0, 1)

    time_len = state_history.shape[0]
    policy_loss = 0.0
    for state_axis, action_axis, policy_target_axis in \
            zip(state_history, action_history, policy_targets):
        axis_len = state_axis.shape[0]
        axis_policy_loss = 0.0

        state_axis = iterate_minibatches(state_axis, batch_size)
        action_axis = iterate_minibatches(action_axis, batch_size)
        policy_target_axis = iterate_minibatches(policy_target_axis, batch_size)

        for state_batch, action_batch, policy_target in \
                zip(state_axis, action_axis, policy_target_axis):
            run_params = [
                reinforce_agent.policy_net.loss,
                reinforce_agent.policy_net.train_op,
                reinforce_agent.feature_net.train_op
            ]
            feed_params = {
                reinforce_agent.feature_net.states: state_batch,
                reinforce_agent.feature_net.is_training: True,
                reinforce_agent.policy_net.actions: action_batch,
                reinforce_agent.policy_net.cumulative_rewards: policy_target,
                reinforce_agent.policy_net.is_training: True
            }
            run_result = sess.run(run_params, feed_dict=feed_params)
            batch_loss_policy = run_result[0]
            axis_policy_loss += batch_loss_policy

        policy_loss += axis_policy_loss / axis_len

    return policy_loss / time_len
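# Hedged sketch of the rollout container this update relies on: each transition only
# needs .state, .action, .reward and .done attributes, holding one time step for a
# batch of parallel environments. The namedtuple and toy shapes below are stand-ins
# for illustration; the project's actual transition type may differ.
from collections import namedtuple

import numpy as np

Transition = namedtuple("Transition", ["state", "action", "reward", "done"])


def make_toy_rollout(n_steps=5, n_envs=2, state_dim=4, n_actions=3):
    # one Transition per time step, each field batched over parallel environments
    return [
        Transition(
            state=np.random.randn(n_envs, state_dim),
            action=np.random.randint(0, n_actions, size=n_envs),
            reward=np.random.randn(n_envs),
            done=np.zeros(n_envs, dtype=bool))
        for _ in range(n_steps)
    ]

# Stacking such a rollout the way `update` does, e.g.
#   np.array([t.state for t in make_toy_rollout()]),
# yields a time-major [n_steps, n_envs, state_dim] array.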