def policy_model(inp, mdp, spec, name="policy", reuse=None, track_scope=None):
    """Build the policy network for a batch of states.

    Args:
        inp: `batch_size * state_dim` input batch.
        mdp: MDP description providing `state_dim` and `action_dim`.
        spec: experiment spec providing `policy_dims` (hidden layer sizes).
        name: variable scope name for the policy parameters.
        reuse: passed through to `tf.variable_scope` for weight sharing.
        track_scope: forwarded to `util.mlp`.

    Returns:
        actions: `batch_size * action_dim`
    """
    # TODO remove magic numbers (the 0.5 stddev below)
    weight_init = tf.truncated_normal_initializer(stddev=0.5)
    with tf.variable_scope(name, reuse=reuse, initializer=weight_init):
        return util.mlp(inp, mdp.state_dim, mdp.action_dim,
                        hidden=spec.policy_dims,
                        track_scope=track_scope)
def critic_model(inp, actions, mdp, spec, name="critic", reuse=None, track_scope=None):
    """Estimate Q(s, a) for each state-action pair in the batch.

    Args:
        inp: `batch_size * state_dim` state batch.
        actions: `batch_size * action_dim` action batch.
        mdp: MDP description providing `state_dim` and `action_dim`.
        spec: experiment spec providing `critic_dims` (hidden layer sizes).
        name: variable scope name for the critic parameters.
        reuse: passed through to `tf.variable_scope` for weight sharing.
        track_scope: forwarded to `util.mlp`.

    Returns:
        `batch_size` vector of Q-value predictions.
    """
    with tf.variable_scope(name, reuse=reuse):
        # Concatenate state and action along the feature axis, then map the
        # joint vector down to a single Q-value per example.
        state_action = tf.concat(1, [inp, actions])
        q_values = util.mlp(state_action,
                            mdp.state_dim + mdp.action_dim, 1,
                            hidden=spec.critic_dims,
                            bias_output=True,
                            track_scope=track_scope)
        # Drop the trailing singleton dimension: (batch, 1) -> (batch,).
        return tf.squeeze(q_values)
# NOTE(review): removed a duplicate definition of `policy_model` here. It was
# behaviorally identical (same signature, same scope/initializer, same
# `util.mlp` call) to the `policy_model` defined earlier in this file and
# silently re-bound the name, shadowing the first definition with no effect.