Example #1
0
  def testCollectionGetSaver(self):
    """Checks that get_saver honours its `collections` filter.

    Without a filter the saver must include the BatchNorm moving statistics;
    restricted to TRAINABLE_VARIABLES it must contain only the trainable
    Linear weights and the BatchNorm offset.
    """
    with tf.variable_scope("prefix") as scope:
      inputs = tf.placeholder(tf.float32, shape=[3, 4])
      hidden = snt.Linear(10)(inputs)
      snt.BatchNorm()(hidden, is_training=True)

    all_vars_saver = snt.get_saver(scope)
    trainable_saver = snt.get_saver(
        scope, collections=(tf.GraphKeys.TRAINABLE_VARIABLES,))

    for saver in (all_vars_saver, trainable_saver):
      self.assertIsInstance(saver, tf.train.Saver)

    trainable_names = ("linear/w", "linear/b", "batch_norm/beta")
    moving_stat_names = ("batch_norm/moving_mean",
                         "batch_norm/moving_variance")

    self.assertEqual(len(all_vars_saver._var_list), 5)
    for name in trainable_names + moving_stat_names:
      self.assertIn(name, all_vars_saver._var_list)

    self.assertEqual(len(trainable_saver._var_list), 3)
    for name in trainable_names:
      self.assertIn(name, trainable_saver._var_list)
    for name in moving_stat_names:
      self.assertNotIn(name, trainable_saver._var_list)
Example #2
0
    def testCollectionGetSaver(self):
        """get_saver with a collections filter drops non-trainable variables.

        The unfiltered saver must map exactly the Linear weights, the
        BatchNorm offset and the BatchNorm moving statistics; the saver
        restricted to TRAINABLE_VARIABLES must map only the first three.
        """
        with tf.variable_scope("prefix") as prefix_scope:
            features = tf.placeholder(tf.float32, shape=[3, 4])
            projected = snt.Linear(10)(features)
            snt.BatchNorm()(projected, is_training=True)

        full_saver = snt.get_saver(prefix_scope)
        trainable_only_saver = snt.get_saver(
            prefix_scope, collections=(tf.GraphKeys.TRAINABLE_VARIABLES, ))

        self.assertIsInstance(full_saver, tf.train.Saver)
        self.assertIsInstance(trainable_only_saver, tf.train.Saver)

        # `_var_list` is a name -> variable mapping, so comparing its key set
        # checks both the count and the membership of every expected name.
        trainable = {"linear/w", "linear/b", "batch_norm/beta"}
        moving_stats = {"batch_norm/moving_mean", "batch_norm/moving_variance"}
        self.assertEqual(set(full_saver._var_list), trainable | moving_stats)
        self.assertEqual(set(trainable_only_saver._var_list), trainable)
Example #3
0
    def _initialize(self):
        """Run variable initializers and build per-component savers.

        Initializes, in one grouped op, the torso, the policy-logits layer,
        the critic head (the baseline layer when the loss class is
        ``BatchA2CLoss``, the Q-values layer otherwise) and the variables of
        both optimizers. Then stores ``(name, saver)`` pairs in
        ``self._savers``.
        """
        uses_baseline = self._loss_class.__name__ == "BatchA2CLoss"
        if uses_baseline:
            critic_name, critic_head = "baseline", self._baseline_layer
        else:
            critic_name, critic_head = "q_head", self._q_values_layer

        def group_initializers(variables):
            # Bundle the initializer ops of `variables` into a single op.
            return tf.group(*[v.initializer for v in variables])

        per_component_inits = [
            group_initializers(self._net_torso.variables),
            group_initializers(self._policy_logits_layer.variables),
            group_initializers(critic_head.variables),
            group_initializers(self._critic_optimizer.variables()),
            group_initializers(self._pi_optimizer.variables()),
        ]
        self._session.run(tf.group(*per_component_inits))

        self._savers = [
            ("torso", snt.get_saver(self._net_torso)),
            ("policy_head", snt.get_saver(self._policy_logits_layer)),
            (critic_name, snt.get_saver(critic_head)),
        ]
Example #4
0
  def testCheckpointCompatibility(self):
    """Checkpoints restore across BatchNormV2 data formats only when shapes fit.

    A checkpoint written by an NHWC module restores into an NCHW module
    connected to NCHW-shaped input. Restoring into an NCHW module connected
    to NHWC-shaped input must raise InvalidArgumentError.
    """
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
    nhwc_shape = (31, 7, 7, 5)
    nchw_shape = (31, 5, 7, 7)

    def build_saver(shape, data_format, is_training):
      # Connect a fresh BatchNormV2 to a new placeholder and return its saver.
      inputs = tf.placeholder(tf.float32, shape=shape)
      module = snt.BatchNormV2(data_format=data_format)
      module(inputs, is_training=is_training)
      return snt.get_saver(module)

    writer = build_saver(nhwc_shape, "NHWC", True)
    compatible_reader = build_saver(nchw_shape, "NCHW", False)
    mismatched_reader = build_saver(nhwc_shape, "NCHW", False)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      writer.save(sess, save_path)
      compatible_reader.restore(sess, save_path)
      with self.assertRaises(tf.errors.InvalidArgumentError):
        mismatched_reader.restore(sess, save_path)
Example #5
0
  def testCheckpointCompatibility(self):
    """Save from NHWC BatchNormV2, restore into NCHW modules of two input shapes.

    The restore into the module built on NCHW-shaped input must succeed. The
    restore into the NCHW module built on NHWC-shaped input must raise
    InvalidArgumentError.
    """
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

    # (input shape, data_format, is_training) for: the saving module, a
    # compatible restoring module, and an incompatible restoring module.
    module_specs = [
        ((31, 7, 7, 5), "NHWC", True),
        ((31, 5, 7, 7), "NCHW", False),
        ((31, 7, 7, 5), "NCHW", False),
    ]
    savers = []
    for shape, data_format, is_training in module_specs:
      inputs = tf.placeholder(tf.float32, shape=shape)
      batch_norm = snt.BatchNormV2(data_format=data_format)
      batch_norm(inputs, is_training=is_training)
      savers.append(snt.get_saver(batch_norm))
    source_saver, matching_saver, mismatching_saver = savers

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      source_saver.save(sess, save_path)
      matching_saver.restore(sess, save_path)
      with self.assertRaises(tf.errors.InvalidArgumentError):
        mismatching_saver.restore(sess, save_path)
def evaluate(crop_size_x, crop_size_y, feature_normalization, checkpoint_path,
             normalization_exclusion, eval_config, network_config):
  """Main evaluation loop.

  Builds the contacts experiment over `eval_config.eval_sstable`, restores the
  model's global and moving-average variables from `checkpoint_path`, and
  calls `_run_evaluation` with `eval_config.output_path` as the output
  directory.
  """
  experiment = contacts_experiment.Contacts(
      tfrecord=eval_config.eval_sstable,
      stats_file=eval_config.stats_file,
      network_config=network_config,
      crop_size_x=crop_size_x,
      crop_size_y=crop_size_y,
      feature_normalization=feature_normalization,
      normalization_exclusion=normalization_exclusion)

  # Restore moving averages as well as the regular global variables.
  restore_collections = [tf.GraphKeys.GLOBAL_VARIABLES,
                         tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
  model_saver = snt.get_saver(experiment.model,
                              collections=restore_collections)

  output_dir = eval_config.output_path
  with tf.train.SingularMonitoredSession(hooks=[]) as sess:
    logging.info('Restoring from checkpoint %s', checkpoint_path)
    model_saver.restore(sess, checkpoint_path)

    logging.info('Writing output to %s', output_dir)
    started_at = time.time()
    _run_evaluation(
        sess=sess,
        experiment=experiment,
        eval_config=eval_config,
        output_dir=output_dir,
        min_range=network_config.min_range,
        max_range=network_config.max_range,
        num_bins=network_config.num_bins,
        torsion_bins=network_config.torsion_bins)
    logging.info('Finished eval %.1fs', (time.time() - started_at))
Example #7
0
 def testGetSaverModule(self):
     """get_saver on a connected Conv2D maps its `w` and `b` variables."""
     images = tf.placeholder(tf.float32, shape=[1, 10, 10, 3])
     conv_module = snt.Conv2D(output_channels=3, kernel_shape=3)
     conv_module(images)
     conv_saver = snt.get_saver(conv_module)
     self.assertIsInstance(conv_saver, tf.train.Saver)
     for var_name in ("w", "b"):
         self.assertIn(var_name, conv_saver._var_list)
Example #8
0
    def testGetSaverScope(self):
        """A saver built from a scope keys its variables by scope-relative name."""
        with tf.variable_scope("prefix") as prefix_scope:
            for var_name, var_shape in (("a", [5, 6]), ("b", [7])):
                tf.get_variable(var_name, shape=var_shape)

        scope_saver = snt.get_saver(prefix_scope)
        self.assertIsInstance(scope_saver, tf.train.Saver)
        self.assertEqual(set(scope_saver._var_list), {"a", "b"})
Example #9
0
 def testGetSaverModule(self):
   """Saver for a connected Conv2D must contain its kernel `w` and bias `b`."""
   placeholder = tf.placeholder(tf.float32, shape=[1, 10, 10, 3])
   module = snt.Conv2D(output_channels=3, kernel_shape=3)
   module(placeholder)
   module_saver = snt.get_saver(module)
   self.assertIsInstance(module_saver, tf.train.Saver)
   saved_names = module_saver._var_list
   self.assertIn("w", saved_names)
   self.assertIn("b", saved_names)
Example #10
0
  def testGetSaverScope(self):
    """get_saver(scope) strips the scope prefix from variable names."""
    with tf.variable_scope("prefix") as scope:
      tf.get_variable("a", shape=[5, 6])
      tf.get_variable("b", shape=[7])

    expected_names = {"a", "b"}
    saver = snt.get_saver(scope)
    self.assertIsInstance(saver, tf.train.Saver)
    self.assertEqual(set(saver._var_list), expected_names)
Example #11
0
    def testGetSaverPartitioned(self, save_partitioned, load_partitioned):
        """Round-trips conv weights through a checkpoint written by get_saver.

        Saves the module returned by ``_create_conv(partitioned=save_partitioned)``.
        Restores into one built with ``partitioned=load_partitioned``. The
        restored ``w`` must equal the saved value.
        """
        # Use the framework-managed per-test temp dir rather than an ad-hoc
        # tempfile.mkdtemp() directory that would never be removed.
        path = os.path.join(self.get_temp_dir(), "ckpt")

        # Save checkpoint.
        with self.test_session() as sess:
            conv = self._create_conv(partitioned=save_partitioned, name="a")
            saver = snt.get_saver(conv)
            sess.run(tf.global_variables_initializer())
            saver.save(sess, path)
            # tf.identity yields a single tensor even if `w` is partitioned.
            w = tf.identity(conv.w)
            w_value = sess.run(w)

        # Restore checkpoint into a differently named module.
        with self.test_session() as sess:
            conv = self._create_conv(partitioned=load_partitioned, name="b")
            saver = snt.get_saver(conv)
            saver.restore(sess, path)
            w = tf.identity(conv.w)
            self.assertAllEqual(sess.run(w), w_value)
Example #12
0
  def testGetSaverPartitioned(self, save_partitioned, load_partitioned):
    """Checkpoint round-trip between (un)partitioned conv modules.

    Writes a checkpoint from ``_create_conv(partitioned=save_partitioned)``.
    Reads it back into ``_create_conv(partitioned=load_partitioned)`` and
    checks that ``w`` is unchanged.
    """
    # Framework-managed temp dir; tempfile.mkdtemp() would leak a directory
    # per run because nothing here removes it.
    path = os.path.join(self.get_temp_dir(), "ckpt")

    # Save checkpoint.
    with self.test_session() as sess:
      conv = self._create_conv(partitioned=save_partitioned, name="a")
      saver = snt.get_saver(conv)
      sess.run(tf.global_variables_initializer())
      saver.save(sess, path)
      # tf.identity yields a single tensor even if `w` is partitioned.
      w = tf.identity(conv.w)
      w_value = sess.run(w)

    # Restore checkpoint into a differently named module.
    with self.test_session() as sess:
      conv = self._create_conv(partitioned=load_partitioned, name="b")
      saver = snt.get_saver(conv)
      saver.restore(sess, path)
      w = tf.identity(conv.w)
      self.assertAllEqual(sess.run(w), w_value)
Example #13
0
    def _initialize(self):
        """Initialize agent variables and record a saver per component.

        Runs one grouped op containing the initializers of the torso, the
        policy-logits layer, the Q-values layer and both optimizers. Then sets
        ``self._savers`` to ``(name, saver)`` pairs for the three network
        components.
        """
        variable_groups = (
            self._net_torso.variables,
            self._policy_logits_layer.variables,
            self._q_values_layer.variables,
            self._critic_optimizer.variables(),
            self._pi_optimizer.variables(),
        )
        component_inits = [
            tf.group(*[var.initializer for var in group])
            for group in variable_groups
        ]
        self._session.run(tf.group(*component_inits))

        self._savers = [
            ("torso", snt.get_saver(self._net_torso)),
            ("policy_head", snt.get_saver(self._policy_logits_layer)),
            ("q_values_layer", snt.get_saver(self._q_values_layer)),
        ]
Example #14
0
step_op = optimizer.minimize(loss_op_tr)

# Lets an iterable of TF graphs be output from a session as NP graphs.
input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)

#@title Reset session  { form-width: "30%" }

# This cell resets the TensorFlow session, but keeps the same computational
# graph.

# Close the session left over from a previous run of this cell; on the first
# run `sess` does not exist yet, which raises NameError.
try:
    sess.close()
except NameError:
    pass

saver = snt.get_saver(model)

# Create exactly one session. A second tf.Session() here would leak this one,
# because it would never be closed.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# To resume from a checkpoint, restore AFTER running the initializer so the
# restored values are not overwritten:
# saver.restore(sess, "./tmp/model.ckpt")

last_iteration = 0
logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []
Example #15
0
  def __init__(self,
               session,
               player_id,
               state_representation_size,
               num_actions,
               hidden_layers_sizes,
               replay_buffer_capacity=10000,
               batch_size=128,
               replay_buffer_class=ReplayBuffer,
               learning_rate=0.01,
               update_target_network_every=1000,
               learn_every=10,
               discount_factor=1.0,
               min_buffer_size_to_learn=1000,
               epsilon_start=1.0,
               epsilon_end=0.1,
               epsilon_decay_duration=int(1e6),
               optimizer_str="sgd",
               loss_str="mse"):
    """Initialize the DQN agent.

    Builds the online and target Q-networks (``snt.nets.MLP``) and one saver
    per network. Also builds the target-network update op, the TD loss on the
    taken action and the optimizer step that minimises it.

    Args:
      session: TensorFlow session stored on the agent.
      player_id: identifier of the player this agent acts for.
      state_representation_size: length of the info-state vectors fed to the
        networks.
      num_actions: number of actions; width of the Q-network output layer.
      hidden_layers_sizes: sequence (list or tuple) of hidden-layer widths.
      replay_buffer_capacity: capacity passed to ``replay_buffer_class``.
      batch_size: stored as ``self._batch_size``.
      replay_buffer_class: callable constructed with the buffer capacity.
      learning_rate: learning rate of the selected optimizer.
      update_target_network_every, learn_every, min_buffer_size_to_learn,
        epsilon_start, epsilon_end, epsilon_decay_duration: stored on the
        instance; presumably scheduling/exploration settings read by other
        methods.
      discount_factor: discount applied to the bootstrapped next-state value.
      optimizer_str: "adam" or "sgd".
      loss_str: "mse" or "huber".

    Raises:
      ValueError: if ``loss_str`` or ``optimizer_str`` is not supported.
    """
    self.player_id = player_id
    self._session = session
    self._num_actions = num_actions
    # list() lets callers pass a tuple (or any iterable) of hidden sizes;
    # `tuple + list` would raise a TypeError.
    self._layer_sizes = list(hidden_layers_sizes) + [num_actions]
    self._batch_size = batch_size
    self._update_target_network_every = update_target_network_every
    self._learn_every = learn_every
    self._min_buffer_size_to_learn = min_buffer_size_to_learn
    self._discount_factor = discount_factor

    self._epsilon_start = epsilon_start
    self._epsilon_end = epsilon_end
    self._epsilon_decay_duration = epsilon_decay_duration

    # TODO(author6) Allow for optional replay buffer config.
    self._replay_buffer = replay_buffer_class(replay_buffer_capacity)
    self._prev_timestep = None
    self._prev_action = None

    # Step counter to keep track of learning, eps decay and target network.
    self._step_counter = 0

    # Keep track of the last training loss achieved in an update step.
    self._last_loss_value = None

    # Create required TensorFlow placeholders to perform the Q-network updates.
    self._info_state_ph = tf.placeholder(
        shape=[None, state_representation_size],
        dtype=tf.float32,
        name="info_state_ph")
    self._action_ph = tf.placeholder(
        shape=[None], dtype=tf.int32, name="action_ph")
    self._reward_ph = tf.placeholder(
        shape=[None], dtype=tf.float32, name="reward_ph")
    self._is_final_step_ph = tf.placeholder(
        shape=[None], dtype=tf.float32, name="is_final_step_ph")
    self._next_info_state_ph = tf.placeholder(
        shape=[None, state_representation_size],
        dtype=tf.float32,
        name="next_info_state_ph")
    self._legal_actions_mask_ph = tf.placeholder(
        shape=[None, num_actions],
        dtype=tf.float32,
        name="legal_actions_mask_ph")

    self._q_network = snt.nets.MLP(output_sizes=self._layer_sizes)
    self._q_values = self._q_network(self._info_state_ph)
    self._target_q_network = snt.nets.MLP(output_sizes=self._layer_sizes)
    self._target_q_values = self._target_q_network(self._next_info_state_ph)

    self._savers = [
        ("network", snt.get_saver(self._q_network, max_to_keep=1000)),
        ("target_network",
         snt.get_saver(self._target_q_network, max_to_keep=1000)),
    ]

    # Stop gradient to prevent updates to the target network while learning.
    # Applied once here; everything below reads this stopped tensor.
    self._target_q_values = tf.stop_gradient(self._target_q_values)

    self._update_target_network = self._create_target_network_update_op(
        self._q_network, self._target_q_network)

    # Create the loss operations.
    # Sum a large negative constant to illegal action logits before taking the
    # max. This prevents illegal action values from being considered as target.
    illegal_actions = 1 - self._legal_actions_mask_ph
    illegal_logits = illegal_actions * ILLEGAL_ACTION_LOGITS_PENALTY
    max_next_q = tf.reduce_max(
        tf.math.add(self._target_q_values, illegal_logits), axis=-1)
    # TD target; terminal transitions do not bootstrap.
    target = (
        self._reward_ph +
        (1 - self._is_final_step_ph) * self._discount_factor * max_next_q)

    # Select the Q-value of the action actually taken in each batch row.
    action_indices = tf.stack(
        [tf.range(tf.shape(self._q_values)[0]), self._action_ph], axis=-1)
    predictions = tf.gather_nd(self._q_values, action_indices)

    if loss_str == "mse":
      loss_class = tf.losses.mean_squared_error
    elif loss_str == "huber":
      loss_class = tf.losses.huber_loss
    else:
      raise ValueError("Not implemented, choose from 'mse', 'huber'.")

    self._loss = tf.reduce_mean(
        loss_class(labels=target, predictions=predictions))

    if optimizer_str == "adam":
      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer_str == "sgd":
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    else:
      raise ValueError("Not implemented, choose from 'adam' and 'sgd'.")

    self._learn_step = optimizer.minimize(self._loss)
Example #16
0
def train_and_evaluate(experiment, model_setup, meta, run_id, tr_iter, val_iter, eval_iter=None, save=False,
                       propensity_params=None, is_propensity_model=False):
    """Build, train and optionally evaluate one model configuration.

    Builds the model selected by ``model_setup.model_type``, its losses and
    train op. If ``propensity_params`` is given, restores a pre-trained
    propensity model from its latest checkpoint. Then trains via ``_train``,
    optionally saves the model, writes a Chrome trace when
    ``experiment.config.use_tracing`` is set, and evaluates on ``eval_iter``
    when provided.

    Returns:
      ``(loss_records_val, best_loss_id, loss_records_eval)``, where
      ``loss_records_eval`` is ``None`` when ``eval_iter`` is ``None``.
    """
    params = model_setup.model_params
    opt_params = model_setup.opt_params
    model_instance = str2model(model_setup.model_type)(params, meta["n_treatments"])

    tr_iter_next = tr_iter.get_next()
    val_iter_next = val_iter.get_next()
    eval_iter_next = eval_iter.get_next() if eval_iter is not None else None

    # Set up losses
    losses_objective = model_setup.model_params.train_loss.split(",")
    losses_to_record = list(set(losses_objective))
    if is_propensity_model:
        tr_losses, val_losses, eval_losses = _prepare_propensity_losses(
            model_instance, meta, tr_iter_next, val_iter_next)
    else:
        losses_to_record += experiment.additional_losses_to_record.split(",")
        tr_losses, val_losses, eval_losses = _prepare_losses(
            experiment, model_instance, opt_params, meta, losses_to_record,
            tr_iter_next, val_iter_next, eval_iter_next)

    # Set up optimizer & initialization.
    # NOTE(review): `loss` comes from str.split() and is never None, so this
    # filter is a no-op; confirm whether `tr_losses[loss] is not None` was meant.
    objective_loss_train_tensor = tf.add_n(
        [tr_losses[loss] for loss in losses_objective if loss is not None],
        name="objective_loss")

    if not is_propensity_model and "propensity_model" in experiment.keys():
        # Keep propensity model variables frozen
        train_op = get_train_op(opt_params, objective_loss_train_tensor,
                                experiment.propensity_model[0].model_type)
    else:
        train_op = get_train_op(opt_params, objective_loss_train_tensor)

    init_op = [tf.global_variables_initializer(), tr_iter.initializer, val_iter.initializer]
    if eval_iter is not None:
        init_op.append(eval_iter.initializer)

    # Making trainable variables assignable so that they can be restored in early stopping
    trainable_vars = tf.trainable_variables()
    assigns_inputs = [tf.placeholder(dtype=var.dtype, name="assign" + str(i))
                      for i, var in enumerate(trainable_vars)]
    assigns = [tf.assign(var, value) for var, value in zip(trainable_vars, assigns_inputs)]

    # Saver for this model and, if requested, for the pre-trained propensity
    # model that is restored into the session below.
    saver = snt.get_saver(model_instance)
    if propensity_params is not None:
        # `experiment.propensity_model` is a list of setups; use its first entry
        # consistently for the model type and the checkpoint name.
        propensity_setup = experiment.propensity_model[0]
        propensity_model_instance = str2model(propensity_setup.model_type)(
            propensity_params, meta["n_treatments"])
        propensity_saver = snt.get_saver(propensity_model_instance)

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()

    # Training loop
    with tf.train.MonitoredTrainingSession() as session:
        session.run(init_op)

        if propensity_params is not None:
            prop_path = utils.assemble_model_path(
                experiment, propensity_setup.model_name, run_id)
            # Savers are handed the raw session unwrapped from the monitored
            # session's private `_sess` chain.
            propensity_saver.restore(session._sess._sess._sess._sess,
                                     tf.train.latest_checkpoint(prop_path))

        loss_records_val, best_loss_id = _train(
            experiment, model_setup, session, train_op, assigns, assigns_inputs,
            run_options, run_metadata, losses_objective, val_losses,
            losses_to_record)

        if save:
            save_model(experiment, saver, model_setup.model_name, run_id,
                       session._sess._sess._sess._sess)

        # Create timeline for performance analysis
        if experiment.config.use_tracing:
            tl = timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open('timeline.json', 'w') as f:
                f.write(ctf)

        if eval_iter is not None:
            loss_records_eval = _evaluate(experiment, model_setup, session, eval_losses, losses_to_record)
        else:
            loss_records_eval = None

    return loss_records_val, best_loss_id, loss_records_eval
Example #17
0
def train(model_instance,
          params,
          opt_params,
          model_name,
          model_config_string,
          tr_data,
          val_data,
          meta,
          objective_loss_train,
          propensity_model_instance=None,
          additional_losses_to_record=None,
          save=False,
          use_train_data=True,
          run_id=0):
    """Train a propensity model, keeping the weights with the best val loss.

    Connects ``model_instance`` to training and validation inputs and
    minimises the training propensity loss plus regularisation. It
    periodically records train/validation propensity losses and snapshots the
    trainable variables whenever the validation objective improves. After
    training it assigns the best snapshot back and optionally saves the model.

    Args:
      model_instance: called as ``model_instance(x, True/False)`` and returning
        ``(mu_hat, sig_hat)``.
      params: model parameters; ``params.train_loss`` selects the objective.
      opt_params: optimisation parameters (``batch_size``, ``iterations`` and
        whatever ``tf_utils.get_train_op`` reads).
      model_name: base name used when saving.
      model_config_string: unused here.
      tr_data, val_data: data passed to ``_prepare_data`` and, when not reading
        from the database, fed into the iterator initializers.
      meta: dataset metadata; ``meta["n_treatments"]`` is read.
      objective_loss_train: NOTE(review): ignored, since it is overwritten from
        ``params.train_loss`` below; kept for interface compatibility.
      propensity_model_instance: unused here.
      additional_losses_to_record: unused here. Defaults to ``None`` rather than
        a shared mutable list.
      save: whether to save the model after training.
      use_train_data: forwarded to ``_prepare_data``; also affects the save name.
      run_id: forwarded to ``_prepare_data``; also affects the save name.

    Returns:
      ``(loss_records_tr, loss_records_val, best_loss_id)``: recorded losses
      keyed by ``Loss.MSE_F.name`` and the index of the best validation record.
    """
    n_treatments = meta["n_treatments"]

    print("Propensity Configuration: %s" % utils.params_to_string(params))
    print("%s" % utils.params_to_string(opt_params))

    (x_tr, x_val, s_tr_pad, s_val_pad, t_tr, t_val, tr_iter_names,
     val_iter_names, tr_iter_origin, val_iter_origin) = _prepare_data(
         tr_data, val_data, n_treatments, opt_params.batch_size,
         use_train_data, run_id)

    # Create the model for the training and validation inputs.
    mu_hat_tr, sig_hat_tr = model_instance(x_tr, True)
    mu_hat_val, sig_hat_val = model_instance(x_val, False)

    # Define losses. The objective always comes from `params.train_loss`.
    objective_loss_train = utils.str2loss_list(params.train_loss)
    tr_loss_p, _ = _propensity_loss(s_tr_pad, mu_hat_tr, sig_hat_tr)
    val_loss_p, _ = _propensity_loss(s_val_pad, mu_hat_val, sig_hat_val)

    tr_loss = tr_loss_p + _regularization_loss()

    # Define operations
    train_op = tf_utils.get_train_op(opt_params, tr_loss)
    init_op = [
        tf.global_variables_initializer(), val_iter_origin.initializer,
        tr_iter_origin.initializer
    ]

    # Assign ops so the best trainable-variable values can be written back
    # after training (early stopping).
    trainable_vars = tf.trainable_variables()
    assigns_inputs = [
        tf.placeholder(dtype=var.dtype, name="assign" + str(i))
        for i, var in enumerate(trainable_vars)
    ]
    assigns = [
        tf.assign(var, value)
        for var, value in zip(trainable_vars, assigns_inputs)
    ]

    best_loss = np.finfo(np.float32).max
    best_loss_id = 0
    loss_records_tr = {Loss.MSE_F.name: []}
    loss_records_val = {Loss.MSE_F.name: []}

    saver = snt.get_saver(model_instance)
    weights_stored = False

    # Training loop
    with tf.Session() as session:
        if config.data_from_database:
            session.run(init_op)
        else:
            feed_dicts = {
                tr_iter_names[k]: tr_data[k] for k in tr_iter_names
            }
            # Iterate the validation names here (not the training ones) so
            # every validation placeholder is fed from val_data.
            feed_dicts.update(
                {val_iter_names[k]: val_data[k] for k in val_iter_names})
            session.run(init_op, feed_dict=feed_dicts)

        for train_iter in range(opt_params.iterations):
            # Train
            session.run(train_op)

            # Record losses periodically after a short warm-up, and always on
            # the final iteration.
            periodic = (train_iter > 15 and
                        train_iter % config.print_interval_prop == 0)
            if periodic or train_iter == opt_params.iterations - 1:
                curr_tr_loss = session.run(tr_loss_p)
                curr_val_loss = session.run(val_loss_p)

                loss_records_tr[Loss.MSE_F.name].append(curr_tr_loss)
                loss_records_val[Loss.MSE_F.name].append(curr_val_loss)

                print("Iter%04d:\tPropensity loss: %.3f\t%.3f" %
                      (train_iter, curr_tr_loss, curr_val_loss))

                # Break if loss takes on illegal value
                if np.isnan(curr_val_loss) or np.isnan(curr_tr_loss):
                    print("Illegal loss value. Aborting training.")
                    break

                # If the validation objective improved, snapshot the weights;
                # otherwise keep training from the current weights.
                curr_loss = sum(loss_records_val[loss.name][-1]
                                for loss in objective_loss_train)
                if best_loss > curr_loss:
                    best_loss = curr_loss
                    best_loss_id = len(loss_records_val[Loss.MSE_F.name]) - 1
                    trainable_vars_values = session.run(trainable_vars)
                    weights_stored = True
                    print("Saving weights")

        # Restore variables of the best iteration
        if weights_stored:
            session.run(assigns,
                        dict(zip(assigns_inputs, trainable_vars_values)))

        if save:
            name = model_name
            if config.data_from_database:
                name += str(run_id + (not use_train_data))
            tf_utils.save_model(saver, name, session)

    return loss_records_tr, loss_records_val, best_loss_id
Example #18
0
    def __init__(self,
                 session,
                 player_id,
                 info_state_size,
                 num_actions,
                 loss_str="rpg",
                 loss_class=None,
                 hidden_layers_sizes=(128, ),
                 batch_size=128,
                 critic_learning_rate=0.01,
                 pi_learning_rate=0.001,
                 entropy_cost=0.01,
                 num_critic_before_pi=8,
                 additional_discount_factor=1.0,
                 max_global_gradient_norm=None):
        """Initialize the PolicyGradient agent.

        Note that users must provide *only one of* `loss_str` or `loss_class`
        (pass `loss_str=None` when supplying `loss_class`, since `loss_str`
        defaults to "rpg").

        Args:
          session: Tensorflow session.
          player_id: int, player identifier. Usually its position in the game.
          info_state_size: int, info_state vector size.
          num_actions: int, number of actions per info state.
          loss_str: string or None. If string, must be one of ["rpg", "qpg",
            "rm", "a2c"] and defined in `_get_loss_class`. If None, a loss
            class must be passed through `loss_class`. Defaults to "rpg".
          loss_class: Class or None. If Class, it must define the policy
            gradient loss. If None a loss class in a string format must be
            passed through `loss_str`. Defaults to None.
          hidden_layers_sizes: iterable, defines the neural network layers.
            Defaults to (128,), which produces a NN:
            [INPUT] -> [128] -> ReLU -> [OUTPUT].
          batch_size: int, batch size to use for Q and Pi learning. Defaults
            to 128.
          critic_learning_rate: float, learning rate used for Critic (Q or V).
            Defaults to 0.01.
          pi_learning_rate: float, learning rate used for Pi. Defaults to 0.001.
          entropy_cost: float, entropy cost used to multiply the entropy loss.
            Can be set to None to skip entropy computation. Defaults to 0.01.
          num_critic_before_pi: int, number of Critic (Q or V) updates before
            each Pi update. Defaults to 8 (every 8th critic learning step, Pi
            also learns).
          additional_discount_factor: float, additional discount to compute
            returns. Defaults to 1.0, in which case, no extra discount is
            applied.
          max_global_gradient_norm: float or None, maximum global norm of a
            gradient to which the gradient is shrunk if its value is larger.
            Defaults to None, in which case gradients are not clipped.
        """
        # Exactly one of `loss_str` / `loss_class` must be truthy.
        assert bool(loss_str) ^ bool(
            loss_class), "Please provide only one option."
        loss_class = loss_class if loss_class else self._get_loss_class(
            loss_str)

        self.player_id = player_id
        self._session = session
        self._num_actions = num_actions
        self._layer_sizes = hidden_layers_sizes
        self._batch_size = batch_size
        self._extra_discount = additional_discount_factor
        self._num_critic_before_pi = num_critic_before_pi

        self._episode_data = []
        self._dataset = collections.defaultdict(list)
        self._prev_time_step = None
        self._prev_action = None

        # Step counters
        self._step_counter = 0
        self._episode_counter = 0
        self._num_learn_steps = 0

        # Keep track of the last training loss achieved in an update step.
        self._last_loss_value = None

        # Placeholders
        self._info_state_ph = tf.placeholder(shape=[None, info_state_size],
                                             dtype=tf.float32,
                                             name="info_state_ph")
        self._action_ph = tf.placeholder(shape=[None],
                                         dtype=tf.int32,
                                         name="action_ph")
        self._return_ph = tf.placeholder(shape=[None],
                                         dtype=tf.float32,
                                         name="return_ph")

        # Network
        # activate final as we plug logit and qvalue heads afterwards.
        net_torso = snt.nets.MLP(output_sizes=self._layer_sizes,
                                 activate_final=True)
        self._policy_head = snt.Linear(output_size=self._num_actions,
                                       name="policy_head")
        self.policy_logits_network = snt.Sequential(
            [net_torso, self._policy_head])
        self._policy_logits = self.policy_logits_network(self._info_state_ph)
        self._policy_probs = tf.nn.softmax(self._policy_logits)
        # One (name, saver) pair per network component.
        self._savers = [("torso", snt.get_saver(net_torso)),
                        ("policy_head", snt.get_saver(self._policy_head))]
        # Connect the torso again to feed the critic head; a Sonnet module
        # reuses its existing variables when connected a second time.
        torso_out = net_torso(self._info_state_ph)
        # Add baseline (V) head for A2C.
        if loss_class.__name__ == "BatchA2CLoss":
            baseline = snt.Linear(output_size=1, name="baseline")
            self._baseline = tf.squeeze(baseline(torso_out), axis=1)
            self._savers.append(("baseline", snt.get_saver(baseline)))
        else:
            # Add q-values head otherwise
            q_head = snt.Linear(output_size=self._num_actions,
                                name="q_values_head")
            self._q_values = q_head(torso_out)
            self._savers.append(("q_head", snt.get_saver(q_head)))

        # Critic loss
        # Baseline loss in case of A2C
        if loss_class.__name__ == "BatchA2CLoss":
            self._critic_loss = tf.reduce_mean(
                tf.losses.mean_squared_error(labels=self._return_ph,
                                             predictions=self._baseline))
        else:
            # Q-loss otherwise.
            # Pick the Q-value of the taken action in each batch row.
            action_indices = tf.stack(
                [tf.range(tf.shape(self._q_values)[0]), self._action_ph],
                axis=-1)
            value_predictions = tf.gather_nd(self._q_values, action_indices)
            self._critic_loss = tf.reduce_mean(
                tf.losses.mean_squared_error(labels=self._return_ph,
                                             predictions=value_predictions))
        critic_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=critic_learning_rate)

        def minimize_with_clipping(optimizer, loss):
            # Like optimizer.minimize(loss), but rescales gradients by their
            # global norm when max_global_gradient_norm is set.
            grads_and_vars = optimizer.compute_gradients(loss)
            if max_global_gradient_norm is not None:
                grads, variables = zip(*grads_and_vars)
                grads, _ = tf.clip_by_global_norm(grads,
                                                  max_global_gradient_norm)
                grads_and_vars = list(zip(grads, variables))

            return optimizer.apply_gradients(grads_and_vars)

        self._critic_learn_step = minimize_with_clipping(
            critic_optimizer, self._critic_loss)

        # Pi loss
        pg_class = loss_class(entropy_cost=entropy_cost)
        if loss_class.__name__ == "BatchA2CLoss":
            self._pi_loss = pg_class.loss(policy_logits=self._policy_logits,
                                          baseline=self._baseline,
                                          actions=self._action_ph,
                                          returns=self._return_ph)
        else:
            self._pi_loss = pg_class.loss(policy_logits=self._policy_logits,
                                          action_values=self._q_values)
        pi_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=pi_learning_rate)
        self._pi_learn_step = minimize_with_clipping(pi_optimizer,
                                                     self._pi_loss)
        # None when the agent was configured through `loss_class`.
        self._loss_str = loss_str