def _initialize_checkpointer_and_maybe_resume(self,
                                                  checkpoint_file_prefix):
        """Reloads the latest checkpoint if it exists.
    This method will first create a `Checkpointer` object and then call
    `checkpointer.get_latest_checkpoint_number` to determine if there is a valid
    checkpoint in self._checkpoint_dir, and what the largest file number is.
    If a valid checkpoint file is found, it will load the bundled data from this
    file and will pass it to the agent for it to reload its data.
    If the agent is able to successfully unbundle, this method will verify that
    the unbundled data contains the keys 'logs' and 'current_iteration'. It will
    then load the `Logger`'s data from the bundle, and will return the iteration
    number keyed by 'current_iteration' as one of the return values (along with
    the `Checkpointer` object).
    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
    Returns:
      start_iteration: int, the iteration number to start the experiment from.
      experiment_checkpointer: `Checkpointer` object for the experiment.
    """
        self._checkpointer = checkpointer.Checkpointer(
            self.PATH + "checkpoints", checkpoint_file_prefix)

        # Additional checkpointers for the second player and for the
        # "latest"/"test" snapshot directories under self.PATH.
        self._checkpointer2 = checkpointer.Checkpointer(
            self.PATH + "player2", checkpoint_file_prefix)
        self._checkpointerlatest1 = checkpointer.Checkpointer(
            self.PATH + "latest3", checkpoint_file_prefix)
        self._checkpointerlatest2 = checkpointer.Checkpointer(
            self.PATH + "latest4", checkpoint_file_prefix)
        self._checkpointertest = checkpointer.Checkpointer(
            self.PATH + "test", checkpoint_file_prefix)

        self._start_iteration = 0
        # Check if checkpoint exists. Note that the existence of checkpoint 0 means
        # that we have finished iteration 0 (so we will start from iteration 1).
        latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
            self.PATH + "checkpoints")
        print("Latest version:", latest_checkpoint_version)
        return 5
        if latest_checkpoint_version >= 0:
            experiment_data = self._checkpointer.load_checkpoint(
                latest_checkpoint_version)
            print("Read to apple")
            if self._agent.unbundle(self.PATH + "checkpoints",
                                    latest_checkpoint_version,
                                    experiment_data):

                assert 'logs' in experiment_data
                assert 'current_iteration' in experiment_data
                self._logger.data = experiment_data['logs']
                self._start_iteration = experiment_data['current_iteration'] + 1
                tf.logging.info(
                    'Reloaded checkpoint and will start from iteration %d',
                    self._start_iteration)
                # Optionally force testing mode after a successful reload:
                # self.testing = True
                # self._agent.testing = True
                print('Testing mode:', self.testing)
Example 2
    def _initialize_checkpointer_and_maybe_resume(self,
                                                  checkpoint_file_prefix):
        super(FixedReplayRunner,
              self)._initialize_checkpointer_and_maybe_resume(
                  checkpoint_file_prefix)

        # Load a checkpoint at initialization, if an init_checkpoint_dir was given.
        init_checkpoint_dir = (self._agent._init_checkpoint_dir)  # pylint: disable=protected-access
        if (self._start_iteration == 0) and (init_checkpoint_dir is not None):
            if checkpointer.get_latest_checkpoint_number(
                    self._checkpoint_dir) < 0:
                # No checkpoint loaded yet, read init_checkpoint_dir
                init_checkpointer = checkpointer.Checkpointer(
                    init_checkpoint_dir, checkpoint_file_prefix)
                latest_init_checkpoint = checkpointer.get_latest_checkpoint_number(
                    init_checkpoint_dir)
                if latest_init_checkpoint >= 0:
                    experiment_data = init_checkpointer.load_checkpoint(
                        latest_init_checkpoint)
                    if self._agent.unbundle(init_checkpoint_dir,
                                            latest_init_checkpoint,
                                            experiment_data):
                        if experiment_data is not None:
                            assert "logs" in experiment_data
                            assert "current_iteration" in experiment_data
                            self._logger.data = experiment_data["logs"]
                            self._start_iteration = (
                                experiment_data["current_iteration"] + 1)
                        tf.logging.info(
                            "Reloaded checkpoint from %s and will start from iteration %d",
                            init_checkpoint_dir,
                            self._start_iteration,
                        )
Example 3
    def run_experiment(self):
        """Runs a full experiment, spread over multiple iterations."""
        tf.logging.info('Beginning evaluation...')
        # Use the checkpointer class.
        self._checkpointer = checkpointer.Checkpointer(
            self._checkpoint_dir, self._checkpoint_file_prefix)
        checkpoint_version = -1
        # Check new checkpoints in a loop.
        while True:
            # Check if checkpoint exists.
            # Note that the existence of checkpoint 0 means that we have finished
            # iteration 0 (so we will start from iteration 1).
            latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
                self._checkpoint_dir)
            # If no new checkpoint has appeared yet, wait and poll again.
            if latest_checkpoint_version <= checkpoint_version:
                time.sleep(self._min_interval_secs)
                continue
            checkpoint_version = latest_checkpoint_version
            experiment_data = self._checkpointer.load_checkpoint(
                latest_checkpoint_version)
            assert self._agent.unbundle(self._checkpoint_dir,
                                        latest_checkpoint_version,
                                        experiment_data)

            self._run_eval_phase(experiment_data['total_steps'])
            if self._test_mode:
                break
Example 4
    def testGarbageCollectionWithCheckpointFrequency(self):
        custom_prefix = 'custom_prefix'
        checkpoint_frequency = 3
        exp_checkpointer = checkpointer.Checkpointer(
            self._test_subdir,
            checkpoint_file_prefix=custom_prefix,
            checkpoint_frequency=checkpoint_frequency)
        data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
        deleted_log_files = 6
        total_log_files = (checkpointer.CHECKPOINT_DURATION *
                           checkpoint_frequency) + deleted_log_files + 1

        # Checkpoints are written at iterations 0, 3, 6, 9, 12, 15 and 18; we
        # check that 0, 3 and 6 have been garbage-collected.
        for iteration_number in range(total_log_files):
            exp_checkpointer.save_checkpoint(iteration_number, data)
        for iteration_number in range(total_log_files):
            prefixes = [custom_prefix, 'sentinel_checkpoint_complete']
            for prefix in prefixes:
                checkpoint_file = os.path.join(
                    self._test_subdir, '{}.{}'.format(prefix,
                                                      iteration_number))
                if iteration_number <= deleted_log_files:
                    self.assertFalse(tf.gfile.Exists(checkpoint_file))
                else:
                    if iteration_number % checkpoint_frequency == 0:
                        self.assertTrue(tf.gfile.Exists(checkpoint_file))
                    else:
                        self.assertFalse(tf.gfile.Exists(checkpoint_file))
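The retention rule this test encodes can be restated in a couple of lines of plain Python. A small sketch follows; the helper name is made up, and `CHECKPOINT_DURATION` is inferred to be 4 from the test's own arithmetic (CHECKPOINT_DURATION * 3 + 6 + 1 = 19 iterations):

def surviving_checkpoints(total_iterations, frequency, duration):
    # Only multiples of `frequency` are ever written, and only the most
    # recent `duration` of those survive garbage collection.
    written = [i for i in range(total_iterations) if i % frequency == 0]
    return written[-duration:]

# 19 iterations with frequency 3 write 0, 3, 6, 9, 12, 15, 18 and keep the
# last four, so 0, 3 and 6 are the ones the test expects to be deleted.
assert surviving_checkpoints(19, 3, 4) == [9, 12, 15, 18]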
Example 5
def initialize_checkpointer(checkpoint_dir, checkpoint_file_prefix, agent):
    """Reloads the latest checkpoint if it exists.

    This method will first create a `Checkpointer` object and then call
    `checkpointer.get_latest_checkpoint_number` to determine if there is a valid
    checkpoint in self._checkpoint_dir, and what the largest file number is.
    If a valid checkpoint file is found, it will load the bundled data from this
    file and will pass it to the agent for it to reload its data.
    If the agent is able to successfully unbundle, this method will verify that
    the unbundled data contains the keys,'logs' and 'current_iteration'. It will
    then load the `Logger`'s data from the bundle, and will return the iteration
    number keyed by 'current_iteration' as one of the return values (along with
    the `Checkpointer` object).

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.

    Returns:
      start_iteration: int, the iteration number to start the experiment from.
      experiment_checkpointer: `Checkpointer` object for the experiment.
    """
    checkpointer_ = checkpointer.Checkpointer(checkpoint_dir,
                                              checkpoint_file_prefix)
    start_iteration = 0
    # Check if checkpoint exists. Note that the existence of checkpoint 0 means
    # that we have finished iteration 0 (so we will start from iteration 1).
    latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
        checkpoint_dir)
    if latest_checkpoint_version >= 0:
        experiment_data = checkpointer_.load_checkpoint(
            latest_checkpoint_version)
        if agent.unbundle(checkpoint_dir, latest_checkpoint_version,
                          experiment_data):
            start_iteration = experiment_data['current_iteration'] + 1
    return start_iteration, checkpointer_
Example 6
 def testCheckpointingInitialization(self):
     # Fails with empty directory.
     with self.assertRaisesRegexp(ValueError,
                                  'No path provided to Checkpointer.'):
         checkpointer.Checkpointer('')
     # Fails with invalid directory.
     invalid_dir = '/does/not/exist'
     with self.assertRaisesRegexp(
             ValueError,
             'Unable to create checkpoint path: {}.'.format(invalid_dir)):
         checkpointer.Checkpointer(invalid_dir)
     # Succeeds with valid directory.
     checkpointer.Checkpointer('/tmp/dopamine_tests')
     # This verifies initialization still works after the directory has already
     # been created.
     self.assertTrue(tf.gfile.Exists('/tmp/dopamine_tests'))
     checkpointer.Checkpointer('/tmp/dopamine_tests')
Example 7
 def testLoadLatestCheckpoint(self):
     exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
     first_iter = 1729
     exp_checkpointer.save_checkpoint(first_iter, first_iter)
     second_iter = first_iter + 1
     exp_checkpointer.save_checkpoint(second_iter, second_iter)
     self.assertEqual(
         second_iter,
         checkpointer.get_latest_checkpoint_number(self._test_subdir))
Example 8
 def testLogToFileWithValidDirectoryDefaultPrefix(self):
     exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
     data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
     iteration_number = 1729
     exp_checkpointer.save_checkpoint(iteration_number, data)
     loaded_data = exp_checkpointer.load_checkpoint(iteration_number)
     self.assertEqual(data, loaded_data)
     self.assertIsNone(
         exp_checkpointer.load_checkpoint(iteration_number + 1))
Example 9
 def testLogToFileWithValidDirectoryCustomPrefix(self):
     prefix = 'custom_prefix'
     exp_checkpointer = checkpointer.Checkpointer(
         self._test_subdir, checkpoint_file_prefix=prefix)
     data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
     iteration_number = 1729
     exp_checkpointer.save_checkpoint(iteration_number, data)
     loaded_data = exp_checkpointer.load_checkpoint(iteration_number)
     self.assertEqual(data, loaded_data)
     self.assertIsNone(
         exp_checkpointer.load_checkpoint(iteration_number + 1))
Example 10
    def _initialize_checkpointer_and_maybe_resume(self,
                                                  checkpoint_file_prefix):
        """Reloads the latest checkpoint if it exists.

    This method will first create a `Checkpointer` object and then call
    `checkpointer.get_latest_checkpoint_number` to determine if there is a valid
    checkpoint in self._checkpoint_dir, and what the largest file number is.
    If a valid checkpoint file is found, it will load the bundled data from this
    file and will pass it to the agent for it to reload its data.
    If the agent is able to successfully unbundle, this method will verify that
    the unbundled data contains the keys 'logs' and 'current_iteration'. It will
    then load the `Logger`'s data from the bundle, and will return the iteration
    number keyed by 'current_iteration' as one of the return values (along with
    the `Checkpointer` object).

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.

    Returns:
      start_iteration: int, the iteration number to start the experiment from.
      experiment_checkpointer: `Checkpointer` object for the experiment.
    """
        self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                       checkpoint_file_prefix)
        self._start_iteration = 0
        # Check if checkpoint exists. Note that the existence of checkpoint 0 means
        # that we have finished iteration 0 (so we will start from iteration 1).
        latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
            self._checkpoint_dir)
        if latest_checkpoint_version >= 0:
            experiment_data = self._checkpointer.load_checkpoint(
                latest_checkpoint_version)
            if self._agent.unbundle(self._checkpoint_dir,
                                    latest_checkpoint_version,
                                    experiment_data):
                if experiment_data is not None:
                    assert 'logs' in experiment_data
                    assert 'current_iteration' in experiment_data
                    self._logger.data = experiment_data['logs']
                    self._start_iteration = experiment_data[
                        'current_iteration'] + 1
                    if self._environment.game_name.startswith('VGDL'):
                        self._environment.set_level(
                            experiment_data['vgdl_level'],
                            experiment_data['training_steps'])
                tf.logging.info(
                    'Reloaded checkpoint and will start from iteration %d',
                    self._start_iteration)
Example 11
 def testGarbageCollection(self):
   custom_prefix = 'custom_prefix'
   exp_checkpointer = checkpointer.Checkpointer(
       self._test_subdir, checkpoint_file_prefix=custom_prefix)
   data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
   deleted_log_files = 7
   total_log_files = checkpointer.CHECKPOINT_DURATION + deleted_log_files
   for iteration_number in range(total_log_files):
     exp_checkpointer.save_checkpoint(iteration_number, data)
   for iteration_number in range(total_log_files):
     prefixes = [custom_prefix, 'sentinel_checkpoint_complete']
     for prefix in prefixes:
       checkpoint_file = os.path.join(self._test_subdir, '{}.{}'.format(
           prefix, iteration_number))
       if iteration_number < deleted_log_files:
         self.assertFalse(tf.gfile.Exists(checkpoint_file))
       else:
         self.assertTrue(tf.gfile.Exists(checkpoint_file))
Example 12
    def _initialize_checkpointer_and_maybe_resume(self,
                                                  checkpoint_file_prefix):
        """Reloads the latest checkpoint if it exists.

    This method will first create a `Checkpointer` object and then call
    `checkpointer.get_latest_checkpoint_number` to determine if there is a valid
    checkpoint in self._checkpoint_dir, and what the largest file number is.
    If a valid checkpoint file is found, it will load the bundled data from this
    file and will pass it to the agent for it to reload its data.
    If the agent is able to successfully unbundle, this method will increment
    the iteration number keyed by 'current_iteration' and the step number keyed
    by 'total_steps', and return both.

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
    Returns:
      start_iteration: The iteration number to be continued after the latest
        checkpoint.
      start_step: The step number to be continued after the latest checkpoint.
    """
        self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                       checkpoint_file_prefix)
        start_iteration = 0
        start_step = 0
        # Check if checkpoint exists.
        # Note that the existence of checkpoint 0 means that we have finished
        # iteration 0 (so we will start from iteration 1).
        latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
            self._checkpoint_dir)
        if latest_checkpoint_version >= 0:
            assert not self._episode_writer, 'Can only log episodes from scratch.'
            experiment_data = self._checkpointer.load_checkpoint(
                latest_checkpoint_version)
            start_iteration = experiment_data['current_iteration'] + 1
            del experiment_data['current_iteration']
            start_step = experiment_data['total_steps'] + 1
            del experiment_data['total_steps']
            if self._agent.unbundle(self._checkpoint_dir,
                                    latest_checkpoint_version,
                                    experiment_data):
                tf.logging.info(
                    'Reloaded checkpoint and will start from '
                    'iteration %d', start_iteration)
        return start_iteration, start_step
Example 13
 def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds(
     self, mock_get_latest):
   latest_checkpoint = 7
   mock_get_latest.return_value = latest_checkpoint
   logs_data = {'a': 1, 'b': 2}
   current_iteration = 1729
   checkpoint_data = {'current_iteration': current_iteration,
                      'logs': logs_data}
   checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
   checkpoint = checkpointer.Checkpointer(checkpoint_dir, 'ckpt')
   checkpoint.save_checkpoint(latest_checkpoint, checkpoint_data)
   mock_agent = mock.Mock()
   mock_agent.unbundle.return_value = True
   runner = run_experiment.Runner(self._test_subdir,
                                  lambda x, y, summary_writer: mock_agent,
                                  mock.Mock)
   expected_iteration = current_iteration + 1
   self.assertEqual(expected_iteration, runner._start_iteration)
   self.assertDictEqual(logs_data, runner._logger.data)
   mock_agent.unbundle.assert_called_once_with(
       checkpoint_dir, latest_checkpoint, checkpoint_data)
Example 14

  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
   self._agent.reload_checkpoint(
       self._trained_agent_checkpoint_path)
   self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                  checkpoint_file_prefix)
   self._start_iteration = 0
Example 15
def run_training(environment, exploration_fn, random_state, args_dict=None):
    """Executes the training loop.

  Args:
    environment: Instance of Environment.
    exploration_fn: function, this can be linear decay or constant exploration
      rate.
    random_state: np.random.RandomState, maintain the random generator state.
    args_dict: dictionary, contains all the necessary configuration params.
  """
    last_layer_weights = environment.last_layer_weights()
    last_layer_biases = environment.last_layer_biases()
    last_layer_target_weights = environment.last_layer_target_weights()
    last_layer_target_biases = environment.last_layer_target_biases()
    num_actions = args_dict["num_actions"]
    state_dimensions = args_dict["state_dimensions"]

    checkpoint_dir = os.path.join(FLAGS.checkpoint_path, "checkpoint/")
    checkpoint_handler = checkpointer.Checkpointer(
        checkpoint_dir, checkpoint_frequency=CHECKPOINT_FREQUENCY)
    checkpoint_version = checkpointer.get_latest_checkpoint_number(
        checkpoint_dir)
    if checkpoint_version >= 0:
        checkpoint_data = checkpoint_handler.load_checkpoint(
            checkpoint_version)
        print(f"Restored checkpoint for iteration {checkpoint_version}.")
        # TODO(dijiasu): Revisit if need to run agent._sync_qt_ops().
    else:
        print("No checkpoint found. Initializing all variables.")
        # Initialize a CheckpointData object if there is no checkpoint yet.
        target_weight = [last_layer_target_weights, last_layer_target_biases]
        pre_load = [last_layer_weights, last_layer_biases]
        node_queue = queue.Queue()
        if abs(args_dict["consistency_coeff"]) <= EPS_TOLERANCE:
            for _ in range(args_dict["max_num_nodes"]):
                node_queue.put(
                    (pre_load, [0], target_weight, target_weight, 0, 0, 0, 0))
        else:
            node_queue.put(
                (pre_load, [0], target_weight, target_weight, 0, 0, 0, 0))

        checkpoint_data = CheckpointData(
            batch_feature_vecs=np.zeros(
                (args_dict["steps_per_iteration"], state_dimensions),
                dtype="f8"),
            target_weight=[
                last_layer_target_weights, last_layer_target_biases
            ],
            tree_exp_replay=ConsistencyBuffer(random_state, args_dict),
            pre_load=pre_load,
            node_queue=node_queue,
            tree_level_size=node_queue.qsize(),
            weights_to_be_rollout=pre_load,
            best_sampler=None,
            back_tracking_count=args_dict["back_tracking_count"],
            back_tracking_node_list=args_dict["back_tracking_node_list"],
            start_i=0,
            optimizer=tf.train.RMSPropOptimizer(
                learning_rate=FLAGS.learning_rate,
                decay=0.95,
                epsilon=0.00001,
                centered=True),
            initial_batch_data=None)

        # Check initial performance.
        # TODO(dijiasu): Consider removing.
        rollout_layer = create_single_layer_from_weights(
            num_actions, state_dimensions, pre_load)
        (avg_actual_return, avg_predicted_return,
         avg_q_val) = environment.evaluate_policy(
             random_state, rollout_layer, epsilon_eval=args_dict["eval_eps"])
        print("initial batch #%d true_val: %.2f predicted_val: %.2f\n", 0,
              avg_actual_return, avg_predicted_return)
        for iteration_children in range(args_dict["max_num_nodes"]):
            with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
                    args_dict["num_children"], args_dict["max_num_nodes"],
                    args_dict["learning_rate"],
                    args_dict["sample_consis_buffer"])):
                with tf.name_scope(
                        "children_cnt={}".format(iteration_children)):
                    tf.compat.v2.summary.scalar("actual_return",
                                                avg_actual_return,
                                                step=0)
                    tf.compat.v2.summary.scalar("predic_return",
                                                avg_predicted_return,
                                                step=0)
        with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
                args_dict["num_children"], args_dict["max_num_nodes"],
                args_dict["learning_rate"],
                args_dict["sample_consis_buffer"])):
            with tf.name_scope("Best!"):
                tf.compat.v2.summary.scalar("best_actual_return",
                                            avg_actual_return,
                                            step=0)
                tf.compat.v2.summary.scalar("indice", 0, step=0)

    batch_feature_vecs = checkpoint_data.batch_feature_vecs
    target_weight = checkpoint_data.target_weight
    tree_exp_replay = checkpoint_data.tree_exp_replay
    pre_load = checkpoint_data.pre_load
    q = checkpoint_data.node_queue
    tree_level_size = checkpoint_data.tree_level_size
    weights_to_be_rollout = checkpoint_data.weights_to_be_rollout
    best_sampler = checkpoint_data.best_sampler
    back_tracking_count = checkpoint_data.back_tracking_count
    back_tracking_node_list = checkpoint_data.back_tracking_node_list
    start_i = checkpoint_data.start_i
    optimizer = checkpoint_data.optimizer
    initial_batch_data = checkpoint_data.initial_batch_data
    for i in range(start_i, FLAGS.training_iterations):
        print(f"Starting iteration {i}.")
        level_weights = []
        target_layer_list = []
        for level in range(tree_level_size):
            level_q = q.get()
            (parent_weight, parent_index, target_weight, old_target_weight, _,
             _, _, _) = level_q
            old_target = create_single_layer_from_weights(
                num_actions, state_dimensions, old_target_weight)
            target_layer_list.append(old_target)
            q.put(level_q)

        # Launch the agent and sample one batch of data from the env
        if best_sampler is not None:
            single_batch = exploration_fn(
                random_state,
                create_single_layer_from_weights(num_actions, state_dimensions,
                                                 best_sampler),
                target_layer_list)
        else:
            single_batch = exploration_fn(
                random_state,
                create_single_layer_from_weights(num_actions, state_dimensions,
                                                 pre_load), target_layer_list)
        if initial_batch_data is None:
            initial_batch_data = process_initial_batch(single_batch, args_dict)

        for level in range(tree_level_size):
            (parent_weight, parent_index, target_weight, old_target_weight, _,
             _, _, parent_score) = q.get()

            # We update the target weights
            if i > 0 and i % args_dict["target_replacement_freq"] == 0:
                old_target_weight = target_weight

            # Do the Q-update here, splitting the parent node into multiple
            # child nodes.
            children_weights = update_single_step(
                batch_feature_vecs, random_state, parent_weight,
                tree_level_size, tree_exp_replay, parent_index,
                old_target_weight, i, single_batch, level, parent_score,
                optimizer, args_dict)

            for children_w in children_weights:
                q.put(children_w)
                level_weights.append(children_w)

        # Advance the experience buffer
        tree_exp_replay.next_level()
        # Deleting previous level experience
        tree_exp_replay.forget()
        tree_level_size = q.qsize()

        if i % args_dict["rollout_freq"] == 0:
            children_cnt = 0
            actual_return_ls = []
            children_score = []
            children_loss_bellman = []
            children_loss_regularize = []
            children_loss_total = []
            parent_loss_vector = []
            children_avg_q = []

            num_best_nodes_for_expansion = args_dict[
                "num_best_nodes_for_expansion"]
            for children_w in level_weights:
                with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
                        args_dict["num_children"], args_dict["max_num_nodes"],
                        args_dict["learning_rate"],
                        args_dict["sample_consis_buffer"])):
                    with tf.name_scope("children_cnt={}".format(children_cnt)):
                        (child_weight_np, _, _, _, total_bellman,
                         total_regularized_loss, total_q,
                         parent_score) = children_w
                        weights_to_be_rollout = child_weight_np
                        children_loss_bellman.append(total_bellman)
                        children_loss_regularize.append(total_regularized_loss)
                        children_loss_total.append(total_bellman +
                                                   total_regularized_loss)
                        children_avg_q.append(total_q -
                                              (total_bellman +
                                               total_regularized_loss))
                        parent_loss_vector.append(parent_score)
                        rollout_layer = create_single_layer_from_weights(
                            num_actions, state_dimensions,
                            weights_to_be_rollout)

                        # Evaluate the policy
                        (avg_actual_return, avg_predicted_return,
                         avg_q_val) = environment.evaluate_policy(
                             random_state,
                             rollout_layer,
                             epsilon_eval=args_dict["eval_eps"])

                        tf.compat.v2.summary.scalar("actual_return",
                                                    avg_actual_return,
                                                    step=i + 1)
                        tf.compat.v2.summary.scalar("predic_return",
                                                    avg_predicted_return,
                                                    step=i + 1)
                        tf.compat.v2.summary.scalar("avg_q",
                                                    avg_q_val,
                                                    step=i + 1)
                        actual_return_ls.append(avg_actual_return)
                        children_cnt += 1
                        tf.logging.info(
                            "batch #%d true_val: %.2f predicted_val: %.2f\n",
                            i, avg_actual_return, avg_predicted_return)
                        children_score.append(avg_actual_return)

            children_loss_bellman = np.array(children_loss_bellman)
            children_loss_regularize = np.array(children_loss_regularize)
            children_loss_total = np.array(children_loss_total)
            parent_loss_vector = np.array(parent_loss_vector)
            children_avg_q = np.array(children_avg_q)
            children_score = np.array(children_score)
            children_score_idx = children_score.argsort(
            )[-num_best_nodes_for_expansion:][::-1]
            chosen_children = []

            # Choose scoring function for selecting which nodes to expand
            if args_dict["node_scoring_fn"] == "rollouts":
                # This uses the rollouts as the scoring function
                if (i > 0) and (args_dict["consistency_coeff"] >
                                0) and args_dict["enable_back_tracking"]:
                    for track_cell in back_tracking_node_list:
                        children_score = np.append(children_score,
                                                   track_cell[-1])
                children_score_idx = children_score.argsort(
                )[-num_best_nodes_for_expansion:][::-1]
                chosen_children = children_score.argsort()[::-1]

            elif args_dict["node_scoring_fn"] == "bellman_consistency":
                # Using bellman_consistency as the scoring function
                if (i > 0) and (args_dict["consistency_coeff"] >
                                0) and args_dict["enable_back_tracking"]:
                    mean_loss_total_delta = np.mean(children_loss_total -
                                                    parent_loss_vector)
                    for track_cell in range(len(back_tracking_node_list)):
                        back_tracking_node_list[track_cell][
                            -1] = back_tracking_node_list[track_cell][
                                -1] + mean_loss_total_delta * BACK_TRACK_CALIBRATION
                        children_loss_total = np.append(
                            children_loss_total,
                            back_tracking_node_list[track_cell][-1])
                children_score_idx = children_loss_total.argsort(
                )[:num_best_nodes_for_expansion]
                chosen_children = children_loss_total.argsort()

            try:
                if children_score_idx[0] >= args_dict["max_num_nodes"]:
                    best_sampler = back_tracking_node_list[
                        children_score_idx[0] - args_dict["max_num_nodes"]][0]
            except IndexError as e:
                tf.logging.error(e)
            if i > 0 and args_dict["consistency_coeff"] > 0:
                for queue_iteration in range(len(level_weights)):
                    q.get()
                tf.logging.info("Pruning nodes...")
                if args_dict["enable_back_tracking"]:
                    for queue_iteration in range(len(level_weights)):
                        level_weights[queue_iteration] = list(
                            level_weights[queue_iteration])
                        if args_dict["node_scoring_fn"] == "rollouts":
                            level_weights[queue_iteration][
                                -1] = children_score[queue_iteration]
                        elif args_dict[
                                "node_scoring_fn"] == "bellman_consistency":
                            level_weights[queue_iteration][
                                -1] = level_weights[queue_iteration][
                                    -4] + level_weights[queue_iteration][-3]
                        level_weights[queue_iteration] = tuple(
                            level_weights[queue_iteration])
                # children_score_idx holds the indices of the nodes that score
                # best under the chosen scoring function.
                for queue_iteration in children_score_idx:
                    if args_dict["enable_back_tracking"]:
                        if queue_iteration >= args_dict["max_num_nodes"]:
                            adjust_idx = queue_iteration - args_dict[
                                "max_num_nodes"]
                            q.put(tuple(back_tracking_node_list[adjust_idx]))
                            back_tracking_count += 1
                        else:
                            q.put(level_weights[queue_iteration])
                    else:
                        q.put(level_weights[queue_iteration])
                # Pick the highest scoring nodes
                if args_dict["enable_back_tracking"]:
                    tmp_back_tracking_node_list = []
                    for idx__ in chosen_children[
                            num_best_nodes_for_expansion:]:
                        if idx__ >= args_dict["max_num_nodes"]:
                            adjust_idx = idx__ - args_dict["max_num_nodes"]
                            tmp_back_tracking_node_list.append(
                                back_tracking_node_list[adjust_idx])
                        else:
                            tmp_back_tracking_node_list.append(
                                list(level_weights[idx__]))

                        if len(tmp_back_tracking_node_list
                               ) >= args_dict["back_track_size"]:
                            break
                    back_tracking_node_list = tmp_back_tracking_node_list
                tree_level_size = num_best_nodes_for_expansion

            with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
                    args_dict["num_children"], args_dict["max_num_nodes"],
                    args_dict["learning_rate"],
                    args_dict["sample_consis_buffer"])):
                with tf.name_scope("Best!"):
                    tf.compat.v2.summary.scalar("best_actual_return",
                                                np.max(actual_return_ls),
                                                step=i + 1)
                    # TODO(dijiasu): Check if renaming from indice to index breaks this.
                    tf.compat.v2.summary.scalar("indice",
                                                np.argmax(actual_return_ls),
                                                step=i + 1)
        if i == 0:
            tf.logging.info(
                "Copying the online network weights to target network.")
            # TODO(dijiasu): Revisit if need to run agent._sync_qt_ops().

        # TODO(dijiasu): Keep checkpoint_data variable to avoid re-creating it.
        checkpoint_data = CheckpointData(
            batch_feature_vecs=batch_feature_vecs,
            target_weight=target_weight,
            tree_exp_replay=tree_exp_replay,
            pre_load=pre_load,
            node_queue=q,
            tree_level_size=tree_level_size,
            weights_to_be_rollout=weights_to_be_rollout,
            best_sampler=best_sampler,
            back_tracking_count=back_tracking_count,
            back_tracking_node_list=back_tracking_node_list,
            start_i=i + 1,
            optimizer=optimizer,
            initial_batch_data=initial_batch_data,
        )
        checkpoint_handler.save_checkpoint(i, checkpoint_data)
        args_dict["back_tracking_count"] = back_tracking_count
        args_dict["back_tracking_node_list"] = back_tracking_node_list