Code Example #1
    def _build(self, policy_optimizer, vf_optimizer, scope, reuse=None):
        with tf.variable_scope(scope, reuse=reuse):
            # Inputs to computation graph
            obs_ph = tf.placeholder(tf.float32, shape=(None, self.obs_dim), name='ppo_obs_ph')
            # NOTE: the `and False` guard means the discrete-action branch is
            # never taken, so a float placeholder is always created for actions.
            if self.discrete_actions and False:
                act_ph = tf.placeholder(tf.int32, shape=(None, self.act_dim), name='ppo_act_ph')
            else:
                act_ph = tf.placeholder(tf.float32, shape=(None, self.act_dim), name='ppo_act_ph')

            adv_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_adv_ph')
            ret_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_ret_ph')
            logp_old_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_logp_old_ph')

            # Main outputs from computation graph
            pi, logp, logp_pi, v, act = self.actor_critic(obs_ph, act_ph, self.act_dim, self.discrete_actions)
            # PPO objectives
            ratio = tf.exp(logp - logp_old_ph)          # pi(a|s) / pi_old(a|s)
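            # Note: tf.minimum(ratio * adv_ph, min_adv) below reproduces the
            # standard clipped surrogate min(ratio * A, clip(ratio, 1-eps, 1+eps) * A):
            # for A > 0 only the upper clip bound can bind and for A < 0 only
            # the lower one, which is what the tf.where on the advantage sign selects.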
            min_adv = tf.where(adv_ph > 0, (1+self._clip_ratio) *
                               adv_ph, (1-self._clip_ratio)*adv_ph)
            pi_loss = -tf.reduce_mean(tf.minimum(ratio * adv_ph, min_adv))
            v_loss = tf.reduce_mean((ret_ph - v)**2)

            # Info (useful to watch during learning)
            # a sample estimate for KL-divergence, easy to compute
            approx_kl = tf.reduce_mean(logp_old_ph - logp)
            # a sample estimate for entropy, also easy to compute
            approx_ent = tf.reduce_mean(-logp)
            clipped = tf.logical_or(ratio > (1+self._clip_ratio), ratio < (1-self._clip_ratio))
            clipfrac = tf.reduce_mean(tf.cast(clipped, tf.float32))

            # Optimizers
            update_pi = policy_optimizer.minimize(pi_loss)
            update_v = vf_optimizer.minimize(v_loss)

            self.act = U.function(
                inputs=[obs_ph],
                outputs=[pi, v, logp_pi]
            )
            # Since we don't have a target network for ppo
            # we let the target_act function (as needed for
            # other models to update in a centralised manner)
            # simply be the action function.
            self.target_act = U.function([obs_ph], pi)
            self.train_pi = U.function(
                inputs=[obs_ph, act_ph, logp_old_ph, adv_ph],
                outputs=[pi_loss, approx_kl],
                updates=[update_pi]
            )
            self.train_v = U.function(
                inputs=[obs_ph, act_ph, ret_ph],
                outputs=v_loss,
                updates=[update_v]
            )
            self.trainable_vars = U.scope_vars(U.absolute_scope_name(scope))
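
A minimal usage sketch (not part of the source) of how the functions built above might drive a PPO update, assuming `U.function` returns callables that accept numpy arrays matching their declared inputs and that `agent` is an instance of the surrounding class; the advantage normalisation and KL early-stopping threshold are illustrative additions.

import numpy as np

def ppo_update(agent, obs, acts, logp_old, advs, rets,
               pi_iters=80, v_iters=80, target_kl=0.015):
    # Advantage normalisation is a common PPO trick (an assumption, not shown above).
    advs = (advs - advs.mean()) / (advs.std() + 1e-8)
    for _ in range(pi_iters):
        pi_loss, approx_kl = agent.train_pi(obs, acts, logp_old, advs)
        if approx_kl > target_kl:
            # Stop policy updates early if the new policy drifts too far.
            break
    for _ in range(v_iters):
        v_loss = agent.train_v(obs, acts, rets)
    return pi_loss, v_loss, approx_kl
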
Code Example #2
def om_train_block_processing(lstm_inputs,
                              observations_ph,
                              lstm,
                              action_pred_func,
                              opp_act_space,
                              num_units,
                              lstm_hidden_dim,
                              optimizer,
                              scope,
                              update_period_len,
                              history_embedding_dim,
                              grad_norm_clipping=None,
                              ablate_lstm=False,
                              reuse=None,
                              recurrent_prediction_module=False,
                              recurrent_prediction_dim=32,
                              use_representation_loss=False,
                              positive_dist=3,
                              negative_dist=15):
    with tf.variable_scope(scope, reuse=reuse):
        # Instantiate the opponent actions as a distribution.
        opp_act_space = make_pdtype(opp_act_space)

        # ------------------------ Episode Summarisation ------------------------
        # Store the batch size and then reshape the inputs so that they form a
        # tensor of shape [number of periods x period length x lstm input shape].
        # This enables the summary LSTM model to map each sequence of period
        # length to a single vector.
        batch_size = tf.shape(observations_ph)[0]
        episodes = tf.reshape(lstm_inputs,
                              (-1, update_period_len, lstm_inputs.shape[-1]))
        # Run the summary model.
        episode_summaries, summary_vars = summarise_periods(
            episodes, history_embedding_dim, scope, reuse)
        get_ep_summaries = U.function([observations_ph, lstm_inputs],
                                      episode_summaries)
        # Reshape the outputs to be of the shape
        # batch_size x sequence length (in summary periods) x period embedding dimension
        # This is the final step in preparing the original inputs for passing them through
        # the main LSTM which models opponent learning.
        summaries = tf.reshape(episode_summaries,
                               (batch_size, -1, history_embedding_dim))

        # --------------------- Opponent Learning Modelling ---------------------
        # We set up a placeholder to denote whether or not the model is being
        # trained. During training full trajectories are passed in whereas
        # during play (test time essentially) one prediction is made at a time.
        # This boolean then allows us to account for the differing input shapes.
        training = tf.placeholder(tf.bool,
                                  shape=(),
                                  name='lemol_om_training_boolean')

        # Setting up an initial state for the LSTM such that it is trainable.
        initial_h = tf.Variable(tf.zeros((1, lstm_hidden_dim)),
                                trainable=True,
                                name='LeMOL_om_initial_h')
        initial_c = tf.Variable(tf.zeros((1, lstm_hidden_dim)),
                                trainable=True,
                                name='LeMOL_om_initial_c')

        # However, we only want to use a learned state at t=0.
        # We therefore create a placeholder flag for whether to use the
        # learned initial state or one that is passed in. This allows us
        # to query the LSTM from any state, so it need not be stateful
        # (and hence does not track and update an internal state
        # automatically).
        use_initial_state = tf.placeholder(tf.bool,
                                           shape=(),
                                           name='use_learned_initial_state_ph')
        h_ph = tf.placeholder(tf.float32, (None, lstm_hidden_dim),
                              'LeMOL_om_h_ph')
        c_ph = tf.placeholder(tf.float32, (None, lstm_hidden_dim),
                              'LeMOL_om_c_ph')

        # Set up the state with the correct batch size (if using the
        # learned initial state).
        h = tf.cond(
            use_initial_state,
            lambda: tf.tile(initial_h,
                            (tf.shape(observations_ph)[0], 1)), lambda: h_ph)
        c = tf.cond(
            use_initial_state,
            lambda: tf.tile(initial_c,
                            (tf.shape(observations_ph)[0], 1)), lambda: c_ph)

        # Modelling the opponent learning process with an LSTM.
        # Taking only the first three outputs accommodates the custom LeMOLLSTM
        # (which has some internal learning-feature generation that we no longer
        # use but have not yet fully removed).
        hidden_activations, final_h, final_c = lstm(summaries,
                                                    initial_state=[h, c])[:3]

        # Building the graph for the optional triplet loss is handled by
        # the LeMOL framework.
        if use_representation_loss:
            representation_loss_weight = tf.placeholder(
                tf.float32, (), 'representation_loss_weight')
            all_h = tf.concat([tf.expand_dims(h, 1), hidden_activations],
                              axis=1)
            representation_loss = build_triplet_loss(all_h, positive_dist,
                                                     negative_dist)

        # The hidden activations (the h values) of the LSTM represent the
        # opponent's current point in learning, one per summarised period.
        # However, we wish to predict the opponent's action for each timestep
        # of the following period. We therefore prepend the initial h and drop
        # the final hidden state for prediction.
        # The next few lines repeat these learning-phase representations across
        # the full period of play they represent: we add a time dimension,
        # tile the learning features update_period_len times along it, and
        # reshape so that they can be concatenated with the current
        # observations for opponent action prediction.
        lf = tf.concat([tf.expand_dims(h, 1), hidden_activations[:, :-1]],
                       axis=1)
        lf = tf.reshape(lf, (batch_size, -1, 1, lstm_hidden_dim),
                        name='learning_features_expanded')
        lf = tf.tile(lf, (1, 1, update_period_len, 1))
        lf = tf.reshape(lf, (batch_size, -1, lstm_hidden_dim),
                        name='learning_features_tiled')
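        # Illustrative shape walkthrough (made-up sizes): with 2 trajectories of
        # 4 periods, update_period_len = 10 and lstm_hidden_dim = 32, lf starts
        # as (2, 4, 32), becomes (2, 4, 1, 32) after the first reshape,
        # (2, 4, 10, 32) after tiling, and finally (2, 40, 32): one learning
        # feature per timestep, aligned with observations_ph.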

        # Create a placeholder to allow switching between the use of the
        # initial representation of the opponent's learning ('learning
        # feature') or a previously generated one (from the preceding
        # experience).
        use_initial_lf = tf.placeholder(tf.bool,
                                        shape=(),
                                        name='use_initial_learning_feature')

        # Nested conditionals using the boolean placeholders defined
        # previously to put together and reshape the learning features
        # to be used alongside current observations for opponent action
        # prediction.
        learning_features = tf.cond(
            training,
            lambda: lf,
            lambda: tf.cond(
                use_initial_lf,
                # If the initial learning feature is required we give
                # the learned initial h the right time dimension.
                lambda: tf.tile(tf.expand_dims(initial_h, 1),
                                (1, tf.shape(observations_ph)[1], 1)),
                # Otherwise a previously generated learning feature is fed in
                # (intended for in-play prediction). Note that this single
                # feature is used for all observations, which assumes that
                # predictions are being made for a single stage of the
                # opponent's learning. During training, by contrast, the
                # learning features calculated from the play period summaries
                # above are used.
                lambda: tf.tile(tf.expand_dims(h_ph, 1),
                                (1, tf.shape(observations_ph)[1], 1))))

        # The opponent model takes in the observations concatenated with
        # a learned representation of the current opponent (their state
        # of learning). The prediction function itself is defined elsewhere
        # and is assumed to be a multi-layered perceptron.

        if recurrent_prediction_module:
            opp_pred_input, h_in_ep, c_in_ep, recurrent_om_vars, recurrent_om_feeds, recurrent_om_debug = build_recurrent_om_module(
                learning_features=learning_features,
                observations=observations_ph,
                batch_size=batch_size,
                update_period_len=update_period_len,
                lstm_hidden_dim=lstm_hidden_dim,
                num_units=recurrent_prediction_dim,
                training_bool=training,
                ablate_lemol=ablate_lstm)
        else:
            opp_pred_input = tf.concat([observations_ph, learning_features],
                                       axis=-1)

        om_logits = action_pred_func(opp_pred_input,
                                     scope='action_pred_func',
                                     num_units=num_units,
                                     num_outputs=opp_act_space.ncat)

        # Given the logits we then use the distribution to sample
        # actions for the opponent. This induces some randomness.

        # We could reduce randomness by just taking the argmax
        # Does this randomness help to regularise the opponent model
        # during training and/or make for more realistic behaviour
        # for an agent using this opponent model?
        # To reduce this randomness we form action_deter, which is simply the
        # softmax over the opponent-action logits.
        opp_act_dist = opp_act_space.pdfromflat(om_logits)
        action_deter = U.softmax(om_logits)
        actions = opp_act_dist.sample()

        # Collect variables for training.
        # This may contain some duplicate variables, but duplicates are harmless here.
        om_vars = U.scope_vars(U.absolute_scope_name('action_pred_func'))
        if recurrent_prediction_module:
            om_vars += recurrent_om_vars
        if not ablate_lstm:
            om_vars += summary_vars
            om_vars += lstm.weights
            om_vars += [initial_h, initial_c]

        # Opponent model training is performed as a regression problem targeting
        # the opponent's actions. The targets are therefore the opponent's
        # observed actions, which we aim to predict.
        target_ph = tf.placeholder(
            tf.float32, (None, None, opp_act_space.param_shape()[0]),
            name='om_actions_target')
        # Training uses the softmax cross-entropy loss, which we expect to be
        # better behaved and smoother than a mean squared error loss.
        loss = U.mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=target_ph,
                                                       logits=om_logits))
        # If a representation loss is in use, add it (weighted) to the
        # training loss.
        if use_representation_loss:
            loss += representation_loss_weight * representation_loss
        # Optimisation is conducted with the supplied optimiser, with optional
        # gradient-norm clipping, and is performed with respect to the variables
        # collected above.
        optimize_expr = U.minimize_and_clip(optimizer,
                                            loss,
                                            om_vars,
                                            clip_val=grad_norm_clipping)

        # We track and log accuracy of predictions during training as a
        # performance indicator. This is only a measure of top-1 accuracy
        # not of accuracy over the full opponent action distribution (policy).
        accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.argmax(target_ph, axis=-1),
                         tf.argmax(om_logits, axis=-1)), tf.float32))

        # The accuracy metric is also available to be logged to TensorBoard...
        accuracy_summary = tf.summary.scalar(
            'lemol_om_prediction_training_accuracy', accuracy)
        # ... as is the cross entropy loss.
        loss_summary = tf.summary.scalar('lemol_opponent_model_loss', loss)

        # House keeping for tensorboard.
        training_summaries = tf.summary.merge([loss_summary, accuracy_summary])

        # Finally we produce functions to run the relevant parts of the
        # computation graph constructed above. This simplifies calling later
        # since we do not need to worry about TensorFlow sessions etc.

        # The step function updates the state of the meta-learning LSTM. It
        # takes a series of LSTM inputs and observations and uses them to
        # calculate a new LSTM state, which itself represents the learning
        # state of the opponent.
        step = U.function(inputs=[
            lstm_inputs, observations_ph, h_ph, c_ph, use_initial_state
        ],
                          outputs=[final_h, final_c])

        # The train function trains the LSTM and opponent prediction function to
        # predict the opponent's action given the history of experience and the
        # current observation.
        # We define the function using the utilities defined elsewhere and then
        # partially apply it so that the learning feature used in prediction
        # is always the one generated from the original inputs rather than one
        # passed in.
        # Control flow is in place to adapt to using the recurrent
        # in episode model as appropriate.
        training_inputs = [
            lstm_inputs, observations_ph, target_ph, h_ph, c_ph,
            use_initial_state, training, use_initial_lf
        ]
        training_outputs = [loss, training_summaries, final_h, final_c]
        if recurrent_prediction_module:
            training_inputs += [
                recurrent_om_feeds['h'], recurrent_om_feeds['c'],
                recurrent_om_feeds['use_initial_state']
            ]
        if use_representation_loss:
            training_inputs += [representation_loss_weight]

        train_full = U.function(inputs=training_inputs,
                                outputs=training_outputs,
                                updates=[optimize_expr])

        def train(i, o, t, h, c, init, w=0):
            if recurrent_prediction_module:
                # We always need to use the initial state for the in play recurrent model
                # as we require that the inputs fed in for training are in complete episodes
                # and we then reshape them to process them one episode at a time.
                if use_representation_loss:
                    return train_full(i, o, t, h, c, init, True, False,
                                      np.zeros((1, recurrent_prediction_dim)),
                                      np.zeros((1, recurrent_prediction_dim)),
                                      True, w)
                else:
                    return train_full(i, o, t, h, c, init, True, False,
                                      np.zeros((1, recurrent_prediction_dim)),
                                      np.zeros((1, recurrent_prediction_dim)),
                                      True)
            else:
                # The extra recurrent model inputs are superfluous.
                if use_representation_loss:
                    return train_full(i, o, t, h, c, init, True, False, w)
                else:
                    return train_full(i, o, t, h, c, init, True, False)

        # The act function essentially performs action prediction.
        # The inputs are many and varied because we need to feed
        # values into the graph for all possible paths even if the
        # boolean placeholders are such that only one path will ever
        # be used. To simplify things we therefore partially apply
        # the function created such that the lstm_inputs fed are
        # never used (and so we may freely set their value to 0).
        # This is achieved by setting `use_initial_state` to False
        # so that when `use_initial_lf` is False the learning feature
        # is given by the value passed to `h_ph`. `training` is fixed
        # to be False so that the learning feature is not calculated from
        # `lstm_inputs` but is taken either from the learned initial value
        # or from the value fed in for `h_ph` (according to the value of
        # `use_initial_lf`).
        # The function `act_default` then uses the input observations
        # concatenated with either the initial or the supplied learning
        # feature to predict opponent actions.
        # Control flow is again in place to adapt to using the recurrent
        # in episode model as appropriate.
        act_inputs = [
            observations_ph, h_ph, c_ph, use_initial_lf, training,
            use_initial_state, lstm_inputs
        ]
        act_outputs = [action_deter]
        # We need slightly more inputs if we use a recurrent
        # model within each episode.
        if recurrent_prediction_module:
            act_inputs += [
                recurrent_om_feeds['h'], recurrent_om_feeds['c'],
                recurrent_om_feeds['use_initial_state']
            ]
            act_outputs += [h_in_ep, c_in_ep]

        act_full = U.function(inputs=act_inputs, outputs=act_outputs)

        def act(o, h, l, h2=None, c2=None, init=False):
            if recurrent_prediction_module:
                return act_full(
                    o, h, np.zeros_like(h), l, False, False,
                    np.zeros((int(o.shape[0]), update_period_len,
                              int(lstm_inputs.shape[-1]))), h2, c2, init)
            else:
                # The extra recurrent model inputs are superfluous.
                return act_full(
                    o, h, np.zeros_like(h), l, False, False,
                    np.zeros((int(o.shape[0]), update_period_len,
                              int(lstm_inputs.shape[-1]))))

        # We do the same for the opponent model logits (useful for
        # debugging), which require the same inputs as the action
        # prediction calculation.
        logits_full = U.function(inputs=act_inputs, outputs=om_logits)

        def logits(o, h, l, h2=None, c2=None, init=False):
            if recurrent_prediction_module:
                return logits_full(
                    o, h, np.zeros_like(h), l, False, False,
                    np.zeros((int(o.shape[0]), update_period_len,
                              int(lstm_inputs.shape[-1]))), h2, c2, init)
            else:
                # The extra recurrent model inputs are superfluous.
                return logits_full(
                    o, h, np.zeros_like(h), l, False, False,
                    np.zeros((int(o.shape[0]), update_period_len,
                              int(lstm_inputs.shape[-1]))))

        debug_dict = {
            'om_logits': logits,
            'initial_h': initial_h,
            'initial_c': initial_c,
            'summaries': get_ep_summaries
        }

        if recurrent_prediction_module:
            debug_dict['initial_h_in_ep'] = recurrent_om_debug['initial_h']
            debug_dict['initial_c_in_ep'] = recurrent_om_debug['initial_c']
        return act, step, train, debug_dict
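
A minimal usage sketch (not part of the source) of the callables returned above, assuming the non-recurrent prediction module, no representation loss, `U.function` callables that accept numpy arrays matching their declared inputs, and a single unbatched observation `obs_now`; the helper name and shapes are illustrative.

import numpy as np

def om_block_update(act_om, step_om, train_om, lstm_hidden_dim,
                    lstm_in_batch, obs_batch, opp_act_targets, obs_now):
    # Start from the learned initial state; the fed state is then ignored,
    # but the placeholders must still receive values.
    h0 = c0 = np.zeros((1, lstm_hidden_dim))
    # Train on whole periods (they are reshaped internally into
    # [periods x update_period_len x features]) and keep the final state.
    loss, summaries, h, c = train_om(lstm_in_batch, obs_batch,
                                     opp_act_targets, h0, c0, True)
    # In play, predict the opponent's next action from the current observation
    # and the latest learning feature h (use_initial_lf=False).
    opp_act_probs = act_om(obs_now[None, None], h, False)[0]
    # Advance the learning-process LSTM once a new period has been collected.
    h, c = step_om(lstm_in_batch, obs_batch, h, c, False)
    return loss, opp_act_probs, h, c
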
Code Example #3
def p_train(obs_ph_n,
            opp_act_ph,
            act_space_n,
            p_index,
            p_func,
            q_func,
            optimizer,
            grad_norm_clipping=None,
            num_units=64,
            scope='trainer',
            reuse=None,
            polyak=1e-4,
            decentralised_obs=False):
    # p_index is the agent index.
    with tf.variable_scope(scope, reuse=reuse):
        # create action distributions
        act_pdtype_n = [make_pdtype(act_space) for act_space in act_space_n]

        # set up placeholders for actions
        act_ph_n = [
            act_pdtype_n[i].sample_placeholder([None], name='action' + str(i))
            for i in range(len(act_space_n))
        ]

        # Concatenate the observation of the agent being trained with the opponent action
        # (predicted) as input to the policy function.
        p_input = tf.concat([obs_ph_n[p_index], opp_act_ph], -1)

        # Attain the policy distribution (logits) using an MLP (or whatever
        # predictive function is passed in as p_func).
        p_logits = p_func(p_input,
                          int(act_pdtype_n[p_index].param_shape()[0]),
                          scope='p_func',
                          num_units=num_units)
        p_func_vars = U.scope_vars(U.absolute_scope_name('p_func'))

        # Turn logits into marginal action distribution.
        act_pd = act_pdtype_n[p_index].pdfromflat(p_logits)

        # Sample actions and compute the mean squared logits for regularisation.
        act_sample = act_pd.sample()
        p_reg = tf.reduce_mean(tf.square(p_logits))

        # Prepare the input to the Q function.
        # Add an empty list to ensure deep copy.
        act_input_n = act_ph_n + []
        # Replace action placeholder for agent being trained with sampled action.
        act_input_n[p_index] = act_sample

        # The centralised Q function takes in all observations and all actions;
        # with decentralised_obs set, only this agent's observation is used.
        if decentralised_obs:
            q_input = tf.concat([obs_ph_n[p_index]] + act_input_n, 1)
        else:
            q_input = tf.concat(obs_ph_n + act_input_n, 1)

        # Attain q values
        q = q_func(q_input, 1, scope='q_func', reuse=True,
                   num_units=num_units)[:, 0]

        # The objective is to maximise the (scalar) Q values, so we minimise
        # their negative mean.
        pg_loss = -tf.reduce_mean(q)

        # Calculate the loss with regularisation.
        loss = pg_loss + p_reg * 1e-3

        # Optimisation operation.
        optimize_expr = U.minimize_and_clip(optimizer, loss, p_func_vars,
                                            grad_norm_clipping)

        # Create callable functions for training, actions and policy.
        train = U.function(inputs=obs_ph_n + [opp_act_ph] + act_ph_n,
                           outputs=loss,
                           updates=[optimize_expr])

        act = U.function(inputs=[obs_ph_n[p_index], opp_act_ph],
                         outputs=act_sample)

        logits = U.function(inputs=[obs_ph_n[p_index], opp_act_ph],
                            outputs=p_logits)

        # target network to stabilise training.
        target_logits = p_func(p_input,
                               int(act_pdtype_n[p_index].param_shape()[0]),
                               scope='target_p_func',
                               num_units=num_units)
        target_p_func_vars = U.scope_vars(
            U.absolute_scope_name('target_p_func'))

        # create operation to update target network towards true net
        update_target_p = U.make_update_exp(p_func_vars,
                                            target_p_func_vars,
                                            polyak=polyak)

        # Function for attaining target actions to be used in training.
        target_act_sample = act_pdtype_n[p_index].pdfromflat(
            target_logits).sample()
        target_act = U.function(inputs=[obs_ph_n[p_index], opp_act_ph],
                                outputs=target_act_sample)

        # Calculate the gradient of the output of the policy function with
        # respect to the opponent model prediction as a measure of influence
        # for debugging and analysis.
        om_influence_grad = tf.reduce_mean(
            tf.square(tf.gradients(p_logits, opp_act_ph)))

        om_influence_summary = tf.summary.scalar('LeMOL_om_influence',
                                                 om_influence_grad)
        om_influence = U.function(
            inputs=[obs_ph_n[p_index], opp_act_ph],
            outputs=[om_influence_grad, om_influence_summary])
        # ---------------- END OF OM INFLUENCE CALCULATIONS ----------------

        all_vars = p_func_vars + target_p_func_vars

        return (act, train, update_target_p, all_vars, {
            'logits': logits,
            'target_logits': target_logits,
            'target_act': target_act,
            'om_influence': om_influence
        })
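
A minimal sketch (not part of the source) of one policy update using the callables returned by p_train, assuming `update_target_p` is a zero-argument callable built with `U.function` (as in the MADDPG reference implementation) and that `obs_n_batch` / `act_n_batch` are lists of per-agent batched arrays; the helper name and argument layout are illustrative.

def lemol_policy_update(act_fn, train_fn, update_target_p,
                        obs_n_batch, opp_act_pred_batch, act_n_batch,
                        agent_index):
    # Sample this agent's action conditioned on the predicted opponent action.
    action = act_fn(obs_n_batch[agent_index], opp_act_pred_batch)
    # The training inputs mirror p_train's declaration: all observations,
    # then the opponent-action prediction, then all sampled actions.
    p_loss = train_fn(*(obs_n_batch + [opp_act_pred_batch] + act_n_batch))
    # Polyak-average the target policy network towards the trained one.
    update_target_p()
    return action, p_loss
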
Code Example #4
def q_train(obs_ph_n,
            act_space_n,
            q_index,
            q_func,
            optimizer,
            grad_norm_clipping=None,
            scope='trainer',
            reuse=None,
            num_units=64,
            polyak=1e-4,
            decentralised_obs=False):
    '''
    Arguments
    obs_ph_n - list - Observation placeholders, one per agent.
    act_space_n - list - Action spaces, one per agent.
    q_index - int - The index of the agent being trained.
    q_func - callable - Function that builds the Q network.
    optimizer - tf.train.Optimizer - TensorFlow optimizer used to minimise the loss.
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # create distributions for actions
        act_pdtype_n = [make_pdtype(act_space) for act_space in act_space_n]

        # set up placeholders for actions
        act_ph_n = [
            act_pdtype_n[i].sample_placeholder([None], name='action' + str(i))
            for i in range(len(act_space_n))
        ]
        target_ph = tf.placeholder(tf.float32, [None], name='target')

        # Collect all observations and actions together when performing
        # centralised training; otherwise use only this agent's observation.
        if decentralised_obs:
            q_input = tf.concat([obs_ph_n[q_index]] + act_ph_n, 1)
        else:
            q_input = tf.concat(obs_ph_n + act_ph_n, 1)

        # The q value for the state-action pair.
        q = tf.squeeze(q_func(q_input, 1, scope='q_func', num_units=num_units))
        q_func_vars = U.scope_vars(U.absolute_scope_name('q_func'))

        # Train on the squared loss to an externally constructed target.
        loss = tf.reduce_mean(tf.square(q - target_ph))

        optimize_expr = U.minimize_and_clip(optimizer, loss, q_func_vars,
                                            grad_norm_clipping)

        # Create callable functions for training and attaining Q values
        train = U.function(inputs=obs_ph_n + act_ph_n + [target_ph],
                           outputs=[loss, q],
                           updates=[optimize_expr])
        q_values = U.function(obs_ph_n + act_ph_n, q)

        # target network
        target_q = tf.squeeze(
            q_func(q_input, 1, scope='target_q_func', num_units=num_units))
        target_q_func_vars = U.scope_vars(
            U.absolute_scope_name('target_q_func'))

        # Create operation to update target q-network parameters towards trained q-net
        update_target_q = U.make_update_exp(q_func_vars,
                                            target_q_func_vars,
                                            polyak=polyak)

        target_q_values = U.function(obs_ph_n + act_ph_n, target_q)

        all_vars = q_func_vars + target_q_func_vars

        return train, update_target_q, all_vars, {
            'q_values': q_values,
            'target_q_values': target_q_values
        }
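
A minimal sketch (not part of the source) of a critic update built around the callables returned by q_train; the TD-target construction, the discount factor and the omission of terminal-state masking are illustrative assumptions, and `update_target_q` is assumed to be a zero-argument callable.

def lemol_q_update(train_q, target_q_values, update_target_q,
                   obs_n, act_n, rew, obs_next_n, target_act_n, gamma=0.95):
    # Evaluate the target network at the next observations and target actions.
    q_next = target_q_values(*(obs_next_n + target_act_n))
    # Form the TD target externally, as q_train expects
    # (terminal-state masking omitted for brevity).
    td_target = rew + gamma * q_next
    q_loss, q_vals = train_q(*(obs_n + act_n + [td_target]))
    # Move the target critic towards the trained critic.
    update_target_q()
    return q_loss
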
Code Example #5
def om_train(lstm_inputs,
             observations_ph,
             lstm,
             action_pred_func,
             opp_act_space,
             num_units,
             lstm_hidden_dim,
             optimizer,
             scope,
             episode_len=None,
             history_embedding_dim=None,
             grad_norm_clipping=None,
             ablate_lstm=False,
             reuse=None,
             use_representation_loss=False,
             positive_dist=5,
             negative_dist=40):
    with tf.variable_scope(scope, reuse=reuse):
        # ----------------------------------- SET UP -----------------------------------
        # Set up the opponent policy being modelled as a distribution.
        opp_act_space = make_pdtype(opp_act_space)
        # --------------------------------- END SET UP ---------------------------------

        # ------------------------------- META MODELLING -------------------------------
        # For the LSTM tracking training, set up a learnable initial state.
        # This is the same for any sequence and therefore has first dimension
        # 1; we tile it to match the batch size flexibly.
        initial_h = tf.Variable(tf.zeros((1, lstm_hidden_dim)),
                                trainable=True,
                                name='LeMOL_om_initial_h')
        initial_c = tf.Variable(tf.zeros((1, lstm_hidden_dim)),
                                trainable=True,
                                name='LeMOL_om_initial_c')
        # Adding a boolean to allow switching between using the initial
        # state instantiated above and using a state fed in to the graph
        # from an external source (via the placeholders h_ph and c_ph)
        use_initial_state = tf.placeholder(tf.bool,
                                           shape=(),
                                           name='use_learned_initial_state_ph')
        h_ph = tf.placeholder(tf.float32, (None, lstm_hidden_dim),
                              'LeMOL_om_h_ph')
        c_ph = tf.placeholder(tf.float32, (None, lstm_hidden_dim),
                              'LeMOL_om_c_ph')

        # Use the above to form the initial state for the LSTM which models
        # the opponent (and their learning). This is done by either tiling
        # the learned initial state to the batch size as calculated by the
        # size of the observations placeholder or by simply passing through
        # the placeholder which can be fed with an item of any batch size.
        h = tf.cond(
            use_initial_state,
            lambda: tf.tile(initial_h,
                            (tf.shape(observations_ph)[0], 1)), lambda: h_ph)
        c = tf.cond(
            use_initial_state,
            lambda: tf.tile(initial_c,
                            (tf.shape(observations_ph)[0], 1)), lambda: c_ph)

        # Model opponent learning using an LSTM. We run lstm inputs through
        # and use the hidden activations as a representation of where the
        # opponent is in their learning process.
        # Note that we index to take the first 3 outputs to be able to run
        # consistently across the custom LeMOL LSTM and standard LSTM
        # implementations. Note that this essentially makes the custom
        # LSTM run like standard implementations.
        hidden_activations, final_h, final_c = lstm(lstm_inputs,
                                                    initial_state=[h, c])[:3]

        # If we are using a representation loss we build it using the function
        # defined in the framework file.
        if use_representation_loss:
            representation_loss_weight = tf.placeholder(
                tf.float32, (), 'representation_loss_weight')
            all_h = tf.concat([tf.expand_dims(h, 1), hidden_activations],
                              axis=1)
            representation_loss = build_triplet_loss(all_h, positive_dist,
                                                     negative_dist)
        # --------------------------- END OF META MODELLING ---------------------------

        # ----------------------------- ACTION PREDICTION -----------------------------
        # Initial logic to allow running the LSTM to update learning features or simply
        # pass in previously calculated values.
        # We use the lagged output from the LSTM because in training we try to predict
        # the current opponent action and therefore must use the learning feature from
        # before the current opponent action is known.
        run_lstm = tf.placeholder(tf.bool,
                                  shape=(),
                                  name='om_training_boolean')
        learning_features = tf.cond(
            run_lstm, lambda: tf.concat(
                [tf.expand_dims(h, 1), hidden_activations[:, :-1]], axis=1),
            lambda: tf.tile(tf.expand_dims(h, 1),
                            (1, tf.shape(observations_ph)[1], 1)))
        # We model the opponent's action using the current observation and the modelled
        # learning process feature. Action prediction itself then only considers the
        # context of history through the 'meta modelling' LSTM.
        opp_policy_input = tf.concat([observations_ph, learning_features],
                                     axis=-1)
        if ablate_lstm:
            opp_policy_input = observations_ph
        # Use the function passed in to attain logits for the estimated opponent policy.
        # This passed in function is generally a multi-layered perceptron.
        om_logits = action_pred_func(opp_policy_input,
                                     scope='action_pred_func',
                                     num_units=num_units,
                                     num_outputs=opp_act_space.ncat)

        # Given the logits, form the opponent policy as a distribution so that actions
        # can be sampled if desired.
        # TODO: use the argmax, the logits, or a sample from the distribution?
        # Currently `logits` returns the raw logits, `act` returns the softmax
        # probabilities, and the sampled `actions` tensor is unused.
        opp_act_dist = opp_act_space.pdfromflat(om_logits)
        action_deter = U.softmax(om_logits)
        actions = opp_act_dist.sample()

        # ---------------------------- END ACTION PREDICTION ----------------------------

        # ----------------------------------- TRAINING -----------------------------------
        # Collect weights to train from the LSTM (inc. the initial state) and the action
        # prediction function. They are all trained together where relevant.
        om_vars = U.scope_vars(U.absolute_scope_name('action_pred_func'))
        if not ablate_lstm:
            om_vars += lstm.weights
            om_vars += [initial_h, initial_c]

        # Loss calculation.
        # We require target values - the true actions we wish to predict.
        # The loss is then the cross entropy loss between the predicted and actual action
        # distributions.
        target_ph = tf.placeholder(
            tf.float32, (None, None, opp_act_space.param_shape()[0]),
            name='om_actions_target')
        loss = U.mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=target_ph,
                                                       logits=om_logits))
        if use_representation_loss:
            loss += representation_loss_weight * representation_loss

        # Finally for training, set up an update function for the loss, minimised over the
        # variables of the opponent model (as collected above).
        optimize_expr = U.minimize_and_clip(optimizer,
                                            loss,
                                            om_vars,
                                            clip_val=grad_norm_clipping)
        # --------------------------------- END TRAINING ---------------------------------

        # ------------------------------- METRICS & LOGGING -------------------------------
        # Calculate one hot prediction accuracy as a heuristic metric for prediction
        # performance.
        accuracy = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.argmax(target_ph, axis=-1),
                         tf.argmax(om_logits, axis=-1)), tf.float32))

        # Accuracy is recorded to tensorboard as a summary as is the loss.
        accuracy_summary = tf.summary.scalar(
            'lemol_om_prediction_training_accuracy', accuracy)
        loss_summary = tf.summary.scalar('lemol_opponent_model_loss', loss)
        # We merge these summaries to be able to run them with ease.
        training_summaries = tf.summary.merge([loss_summary, accuracy_summary])
        # ----------------------------- END METRICS & LOGGING -----------------------------

        # ------------------------------- FUNCTION BUILDING -------------------------------
        # step 'increments' the LSTM, updating its state from the latest inputs.
        # The resulting state is returned so that it can later be passed back in
        # as play continues along a given trajectory. Action prediction itself is
        # exposed via act, which returns the deterministic softmax over opponent
        # actions rather than a sampled action, removing sampling stochasticity.
        step_full = U.function(inputs=[
            lstm_inputs, observations_ph, h_ph, c_ph, use_initial_state,
            run_lstm
        ],
                               outputs=[final_h, final_c])

        def step(i, o, h, c, init):
            return step_full(i, o, h, c, init, True)

        # A function to attain logits is made and added to the debugging output dictionary.
        # This provides access to a non-stochastic opponent modelling outcome without
        # explicitly being concerned with the new state of the LSTM.
        logits_full = U.function(inputs=[
            lstm_inputs, observations_ph, h_ph, c_ph, use_initial_state,
            run_lstm
        ],
                                 outputs=om_logits)

        def logits(o, h, init=False):
            return logits_full(
                np.zeros((o.shape[0], 1, int(lstm_inputs.shape[-1]))), o, h,
                np.zeros_like(h), init, False)

        # act provides access to the estimated opponent action (as softmax
        # probabilities).
        act_full = U.function(inputs=[
            lstm_inputs, observations_ph, h_ph, c_ph, use_initial_state,
            run_lstm
        ],
                              outputs=action_deter)

        def act(o, h, init=False):
            return act_full(
                np.zeros((o.shape[0], 1, int(lstm_inputs.shape[-1]))), o, h,
                np.zeros_like(h), init, False)

        # Provide a simple interface to train the model.
        # This function updates the weights of the opponent model returning
        # the loss value, summaries for tensorboard and the state of the LSTM
        # at the end of the sequence passed in which can then be used for a
        # subsequent sequence if needed (as long trajectories are broken into
        # chunks to be processed in turn).
        # The inputs required are those for the lstm (inputs, and the state),
        # observations to then make the action predictions, targets to calculate
        # the loss and a boolean to mark whether or not to use the initial state
        # as will be required at the start of a new (batch of) sequence(s).
        # Importantly, when using the learned initial state the passed-in state
        # is ignored but must still be provided: every placeholder on a possible
        # computation path requires a value, since boolean conditions are
        # evaluated lazily while placeholder feeds are validated eagerly.
        train_inputs = [
            lstm_inputs, observations_ph, target_ph, h_ph, c_ph,
            use_initial_state, run_lstm
        ]
        if use_representation_loss:
            train_inputs += [representation_loss_weight]
        train_full = U.function(
            inputs=train_inputs,
            outputs=[loss, training_summaries, final_h, final_c],
            updates=[optimize_expr])

        def train(i, o, t, h, c, init, w=0):
            if use_representation_loss:
                return train_full(i, o, t, h, c, init, True, w)
            else:
                return train_full(i, o, t, h, c, init, True)

        # --------------------------------- END FUNCTION BUILDING ---------------------------------

        return act, step, train, {
            'logits': logits,
            'initial_h': initial_h,
            'initial_c': initial_c
        }
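
A minimal sketch (not part of the source) of driving the simpler om_train opponent model over a trajectory processed in chunks, under the same assumptions about `U.function` callables as the earlier sketches; names and shapes are illustrative.

import numpy as np

def run_om_over_chunks(act_om, train_om, lstm_hidden_dim, chunks, obs_now):
    # chunks is an iterable of (lstm_inputs, observations, opponent_action_targets).
    h = c = np.zeros((1, lstm_hidden_dim))
    use_initial = True  # only the first chunk starts from the learned initial state
    for lstm_in, obs, targets in chunks:
        loss, summaries, h, c = train_om(lstm_in, obs, targets, h, c, use_initial)
        use_initial = False
    # Predict the opponent's action for the current observation using the
    # learning feature carried over from training (init defaults to False).
    opp_act_probs = act_om(obs_now[None, None], h)
    return loss, opp_act_probs, h, c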