Example #1
def update(sess,
           agent,
           target_agent,
           transitions,
           init_state=None,
           discount_factor=0.99,
           reward_norm=1.0,
           batch_size=32,
           time_major=False):
    loss = 0.0
    time_len = transitions.state.shape[0]

    transitions_it = zip(
        iterate_minibatches(transitions.state, batch_size),
        iterate_minibatches(transitions.action, batch_size),
        iterate_minibatches(transitions.reward, batch_size),
        iterate_minibatches(transitions.next_state, batch_size),
        iterate_minibatches(transitions.done, batch_size))

    for states, actions, rewards, next_states, dones in transitions_it:
        # Double DQN target: choose greedy actions with the online network,
        # evaluate them with the target network.
        qvalues_next = agent.predict_qvalues(sess, next_states)
        best_actions = qvalues_next.argmax(axis=1)
        qvalues_next_target = target_agent.predict_qvalues(sess, next_states)
        # Index by the actual batch length, since the last minibatch
        # may be smaller than batch_size.
        qvalues_next_target = qvalues_next_target[np.arange(len(best_actions)),
                                                  best_actions]

        # TD target: r + gamma * Q_target(s', argmax_a Q_online(s', a)),
        # with the bootstrap term zeroed on terminal transitions.
        td_target = rewards * reward_norm + \
                    np.invert(dones).astype(np.float32) * \
                    discount_factor * qvalues_next_target

        run_params = [
            agent.qvalue_net.loss, agent.qvalue_net.train_op,
            agent.hidden_state.train_op, agent.feature_net.train_op
        ]

        feed_params = {
            agent.feature_net.states: states,
            agent.feature_net.is_training: True,
            agent.qvalue_net.actions: actions,
            agent.qvalue_net.td_target: td_target,
            agent.qvalue_net.is_training: True,
        }

        if agent.special.get("dueling_network", False):
            run_params[0] = agent.agent_loss
            run_params += [agent.value_net.train_op]
            feed_params[agent.value_net.td_target] = td_target  # @TODO: why need to feed?
            feed_params[agent.value_net.is_training] = True

        # Recurrent (DRQN) agents additionally update their hidden/belief state
        # and need the episode boundaries to reset it.
        if isinstance(agent, DrqnAgent):
            run_params += [agent.hidden_state.belief_update]
            feed_params[agent.hidden_state.is_end] = dones

        run_results = sess.run(run_params, feed_dict=feed_params)

        batch_loss = run_results[0]
        loss += batch_loss
    return loss / time_len
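All of the examples on this page assume numpy imported as np plus a few helpers from the surrounding module. The snippet below is only a minimal sketch of what iterate_minibatches is assumed to do here, namely slice the leading axis of an array into consecutive chunks of batch_size; the original helper may differ in details (for instance, it may shuffle the data or drop a trailing incomplete batch).

import numpy as np

def iterate_minibatches(inputs, batch_size):
    # Sketch of the assumed helper: yield consecutive slices of inputs
    # along axis 0, each of up to batch_size elements.
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size]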
Example #2
def seq2seq_iter(data, batch_size, double=False):
    # Yield time-major (seq, target) minibatches; with double=True the swapped
    # (target, seq) pair is also yielded for each batch.
    indices = np.arange(len(data))
    for batch in iterate_minibatches(indices, batch_size):
        batch = [data[i] for i in batch]
        seq, target = zip(*batch)
        seq, seq_len = time_major_batch(seq)
        target, target_len = time_major_batch(target)
        yield seq, seq_len, target, target_len
        if double:
            yield target, target_len, seq, seq_len
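Example #2 additionally relies on a time_major_batch helper that is not shown. Judging by how its results are used (a padded batch plus per-sequence lengths), a plausible sketch looks like the following; the padding value, dtype, and exact layout are assumptions, not taken from the original code.

import numpy as np

def time_major_batch(sequences, pad_value=0):
    # Sketch: pad variable-length sequences to a common length and return a
    # time-major array of shape [max_len, batch] together with the true lengths.
    lengths = np.array([len(seq) for seq in sequences])
    batch = np.full((lengths.max(), len(sequences)), pad_value, dtype=np.int64)
    for i, seq in enumerate(sequences):
        batch[:len(seq), i] = np.asarray(seq)
    return batch, lengths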
Example #3
def update(sess,
           a3c_agent,
           transitions,
           initial_state=None,
           discount_factor=0.99,
           reward_norm=1.0,
           batch_size=32,
           time_major=True):
    policy_targets = []
    value_targets = []
    state_history = []
    action_history = []
    done_history = []

    # Bootstrap the return with the critic's value of the final next_state,
    # zeroed where the episode has already terminated.
    cumulative_reward = np.zeros_like(transitions[-1].reward) + \
                        np.invert(transitions[-1].done) * \
                        a3c_agent.predict_values(sess, transitions[-1].next_state)
    for transition in reversed(transitions):
        # Discounted return accumulated backwards in time, reset at episode ends.
        cumulative_reward = reward_norm * transition.reward + \
                            np.invert(transition.done) * discount_factor * cumulative_reward
        # Advantage estimate used as the policy-gradient target.
        policy_target = cumulative_reward - a3c_agent.predict_values(
            sess, transition.state)

        value_targets.append(cumulative_reward)
        policy_targets.append(policy_target)
        state_history.append(transition.state)
        action_history.append(transition.action)
        done_history.append(transition.done)

    value_targets = np.array(value_targets[::-1])  # time-major
    policy_targets = np.array(policy_targets[::-1])
    state_history = np.array(state_history[::-1])
    action_history = np.array(action_history[::-1])
    done_history = np.array(done_history[::-1])

    if isinstance(a3c_agent, A3CLstmAgent):
        a3c_agent.assign_belief_state(sess, initial_state)

    time_len = state_history.shape[0]
    value_loss, policy_loss = 0.0, 0.0
    # Walk over the time axis (the arrays are time-major); within each step,
    # minibatch over the environment/batch axis.
    for state_axis, action_axis, value_target_axis, policy_target_axis, done_axis in \
            zip(state_history, action_history, value_targets, policy_targets, done_history):
        axis_len = state_axis.shape[0]
        axis_value_loss, axis_policy_loss = 0.0, 0.0

        state_axis = iterate_minibatches(state_axis, batch_size)
        action_axis = iterate_minibatches(action_axis, batch_size)
        value_target_axis = iterate_minibatches(value_target_axis, batch_size)
        policy_target_axis = iterate_minibatches(policy_target_axis,
                                                 batch_size)
        done_axis = iterate_minibatches(done_axis, batch_size)

        batch_generator = merge_generators([
            state_axis, action_axis, value_target_axis, policy_target_axis,
            done_axis
        ])

        for state_batch, action_batch, value_target, policy_target, done_batch in batch_generator:
            run_params = [
                a3c_agent.policy_net.loss, a3c_agent.value_net.loss,
                a3c_agent.policy_net.train_op, a3c_agent.value_net.train_op,
                a3c_agent.feature_net.train_op
            ]
            feed_params = {
                a3c_agent.feature_net.states: state_batch,
                a3c_agent.feature_net.is_training: True,
                a3c_agent.policy_net.actions: action_batch,
                a3c_agent.policy_net.cumulative_rewards: policy_target,
                a3c_agent.policy_net.is_training: True,
                a3c_agent.value_net.td_target: value_target,
                a3c_agent.value_net.is_training: True
            }

            if isinstance(a3c_agent, A3CLstmAgent):
                run_params += [a3c_agent.hidden_state.belief_update]
                feed_params[a3c_agent.hidden_state.is_end] = done_batch

            run_result = sess.run(run_params, feed_dict=feed_params)

            batch_loss_policy = run_result[0]
            batch_loss_state = run_result[1]

            axis_value_loss += batch_loss_state
            axis_policy_loss += batch_loss_policy

        policy_loss += axis_policy_loss / axis_len
        value_loss += axis_value_loss / axis_len

    return policy_loss / time_len, value_loss / time_len
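Example #3 also uses a merge_generators helper that is not listed. From its usage above (iterating several minibatch generators in lockstep), it behaves essentially like zip over the generators; the sketch below captures only that assumed behaviour. A closely related update function for a REINFORCE-style agent follows.

def merge_generators(generators):
    # Sketch of the assumed helper: advance all generators together and
    # yield one tuple of batches per step.
    return zip(*generators)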
def update(sess,
           reinforce_agent,
           transitions,
           initial_state=None,
           discount_factor=0.99,
           reward_norm=1.0,
           batch_size=32,
           time_major=True):
    policy_targets = []
    state_history = []
    action_history = []

    cumulative_reward = np.zeros_like(transitions[-1].reward)
    for transition in reversed(transitions):
        # Discounted Monte Carlo return accumulated backwards, reset at episode ends.
        cumulative_reward = reward_norm * transition.reward + \
                            np.invert(transition.done) * discount_factor * cumulative_reward

        policy_targets.append(cumulative_reward)
        state_history.append(transition.state)
        action_history.append(transition.action)

    # time-major
    policy_targets = np.array(policy_targets[::-1])
    state_history = np.array(state_history[::-1])
    action_history = np.array(action_history[::-1])

    if not time_major:
        state_history = state_history.swapaxes(0, 1)
        action_history = action_history.swapaxes(0, 1)
        policy_targets = policy_targets.swapaxes(0, 1)

    time_len = state_history.shape[0]

    policy_loss = 0.0
    for state_axis, action_axis, policy_target_axis in \
            zip(state_history, action_history, policy_targets):
        axis_len = state_axis.shape[0]
        axis_policy_loss = 0.0

        state_axis = iterate_minibatches(state_axis, batch_size)
        action_axis = iterate_minibatches(action_axis, batch_size)
        policy_target_axis = iterate_minibatches(policy_target_axis,
                                                 batch_size)

        for state_batch, action_batch, policy_target in \
                zip(state_axis, action_axis, policy_target_axis):
            run_params = [
                reinforce_agent.policy_net.loss,
                reinforce_agent.policy_net.train_op,
                reinforce_agent.feature_net.train_op
            ]
            feed_params = {
                reinforce_agent.feature_net.states: state_batch,
                reinforce_agent.feature_net.is_training: True,
                reinforce_agent.policy_net.actions: action_batch,
                reinforce_agent.policy_net.cumulative_rewards: policy_target,
                reinforce_agent.policy_net.is_training: True
            }

            run_result = sess.run(run_params, feed_dict=feed_params)

            batch_loss_policy = run_result[0]

            axis_policy_loss += batch_loss_policy

        policy_loss += axis_policy_loss / axis_len

    return policy_loss / time_len
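For orientation, here is a hedged sketch of how an update like the one above might be driven. The Transition container, the batched env_pool wrapper, and the reinforce_agent.predict_actions call are all hypothetical names inferred from the fields and methods the update functions access; they are not part of the original code.

from collections import namedtuple

# Hypothetical container matching the fields the update functions read.
Transition = namedtuple("Transition", ["state", "action", "reward", "done", "next_state"])

def rollout_and_update(sess, reinforce_agent, env_pool, n_steps=20):
    # Collect a short batched rollout and run one REINFORCE update (sketch).
    transitions = []
    states = env_pool.reset()  # env_pool: hypothetical batched-environment wrapper
    for _ in range(n_steps):
        actions = reinforce_agent.predict_actions(sess, states)  # assumed agent API
        next_states, rewards, dones = env_pool.step(actions)
        transitions.append(Transition(states, actions, rewards, dones, next_states))
        states = next_states
    return update(sess, reinforce_agent, transitions, batch_size=32)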