def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix): """Reloads the latest checkpoint if it exists. This method will first create a `Checkpointer` object and then call `checkpointer.get_latest_checkpoint_number` to determine if there is a valid checkpoiRnt in self._checkpoint_dir, and what the largest file number is. If a valid checkpoint file is found, it will load the bundled data from this file and will pass it to the agent for it to reload its data. If the agent is able to successfully unbundle, this method will verify that the unbundled data contains the keys,'logs' and 'current_iteration'. It will then load the `Logger`'s data from the bundle, and will return the iteration number keyed by 'current_iteration' as one of the return values (along with the `Checkpointer` object). Args: checkpoint_file_prefix: str, the checkpoint file prefix. Returns: start_iteration: int, the iteration number to start the experiment from. experiment_checkpointer: `Checkpointer` object for the experiment. """ self._checkpointer = checkpointer.Checkpointer( self.PATH + "checkpoints", checkpoint_file_prefix) self._checkpointer2 = checkpointer.Checkpointer( self.PATH + "player2", checkpoint_file_prefix) self._checkpointerlatest1 = checkpointer.Checkpointer( self.PATH + "latest3", checkpoint_file_prefix) self._checkpointerlatest2 = checkpointer.Checkpointer( self.PATH + "latest4", checkpoint_file_prefix) self._checkpointertest = checkpointer.Checkpointer( self.PATH + "test", checkpoint_file_prefix) self._start_iteration = 0 # Check if checkpoint exists. Note that the existence of checkpoint 0 means # that we have finished iteration 0 (so we will start from iteration 1). latest_checkpoint_version = checkpointer.get_latest_checkpoint_number( self.PATH + "checkpoints") print("Latest version:", latest_checkpoint_version) return 5 if latest_checkpoint_version >= 0: experiment_data = self._checkpointer.load_checkpoint( latest_checkpoint_version) print("Read to apple") if self._agent.unbundle(self.PATH + "checkpoints", latest_checkpoint_version, experiment_data): assert 'logs' in experiment_data assert 'current_iteration' in experiment_data self._logger.data = experiment_data['logs'] self._start_iteration = experiment_data['current_iteration'] + 1 tf.logging.info( 'Reloaded checkpoint and will start from iteration %d', self._start_iteration) #self.testing = True #self._agent.testing = True print('TESTING:', self.testing) print("BANANA")
def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
  super(FixedReplayRunner, self)._initialize_checkpointer_and_maybe_resume(
      checkpoint_file_prefix)
  # Code for loading a checkpoint at initialization.
  init_checkpoint_dir = self._agent._init_checkpoint_dir  # pylint: disable=protected-access
  if (self._start_iteration == 0) and (init_checkpoint_dir is not None):
    if checkpointer.get_latest_checkpoint_number(self._checkpoint_dir) < 0:
      # No checkpoint loaded yet, read init_checkpoint_dir.
      init_checkpointer = checkpointer.Checkpointer(init_checkpoint_dir,
                                                    checkpoint_file_prefix)
      latest_init_checkpoint = checkpointer.get_latest_checkpoint_number(
          init_checkpoint_dir)
      if latest_init_checkpoint >= 0:
        experiment_data = init_checkpointer.load_checkpoint(
            latest_init_checkpoint)
        if self._agent.unbundle(init_checkpoint_dir, latest_init_checkpoint,
                                experiment_data):
          if experiment_data is not None:
            assert "logs" in experiment_data
            assert "current_iteration" in experiment_data
            self._logger.data = experiment_data["logs"]
            self._start_iteration = experiment_data["current_iteration"] + 1
          tf.logging.info(
              "Reloaded checkpoint from %s and will start from iteration %d",
              init_checkpoint_dir, self._start_iteration)

def run_experiment(self):
  """Runs a full experiment, spread over multiple iterations."""
  tf.logging.info('Beginning evaluation...')
  # Use the checkpointer class.
  self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                 self._checkpoint_file_prefix)
  checkpoint_version = -1
  # Check for new checkpoints in a loop.
  while True:
    # Check if checkpoint exists.
    # Note that the existence of checkpoint 0 means that we have finished
    # iteration 0 (so we will start from iteration 1).
    latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
        self._checkpoint_dir)
    if latest_checkpoint_version <= checkpoint_version:
      # No new checkpoint yet; wait before polling again.
      time.sleep(self._min_interval_secs)
      continue
    checkpoint_version = latest_checkpoint_version
    experiment_data = self._checkpointer.load_checkpoint(
        latest_checkpoint_version)
    assert self._agent.unbundle(self._checkpoint_dir,
                                latest_checkpoint_version, experiment_data)
    self._run_eval_phase(experiment_data['total_steps'])
    if self._test_mode:
      break

def testGarbageCollectionWithCheckpointFrequency(self):
  custom_prefix = 'custom_prefix'
  checkpoint_frequency = 3
  exp_checkpointer = checkpointer.Checkpointer(
      self._test_subdir,
      checkpoint_file_prefix=custom_prefix,
      checkpoint_frequency=checkpoint_frequency)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  deleted_log_files = 6
  total_log_files = (checkpointer.CHECKPOINT_DURATION *
                     checkpoint_frequency) + deleted_log_files + 1
  # The checkpoints will happen at iteration numbers 0, 3, 6, 9, 12, 15, 18.
  # We are checking that checkpoints 0, 3, and 6 are deleted.
  for iteration_number in range(total_log_files):
    exp_checkpointer.save_checkpoint(iteration_number, data)
  for iteration_number in range(total_log_files):
    prefixes = [custom_prefix, 'sentinel_checkpoint_complete']
    for prefix in prefixes:
      checkpoint_file = os.path.join(self._test_subdir,
                                     '{}.{}'.format(prefix, iteration_number))
      if iteration_number <= deleted_log_files:
        self.assertFalse(tf.gfile.Exists(checkpoint_file))
      elif iteration_number % checkpoint_frequency == 0:
        self.assertTrue(tf.gfile.Exists(checkpoint_file))
      else:
        self.assertFalse(tf.gfile.Exists(checkpoint_file))

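# The retention arithmetic in the test above is easy to misread, so the
# following standalone sketch traces it. It assumes
# checkpointer.CHECKPOINT_DURATION == 4 (Dopamine's default); that value is
# an assumption, not stated in this section.
CHECKPOINT_DURATION = 4  # assumed value of checkpointer.CHECKPOINT_DURATION
checkpoint_frequency = 3
deleted_log_files = 6
total_log_files = CHECKPOINT_DURATION * checkpoint_frequency + deleted_log_files + 1  # 19
# Only iterations that are multiples of checkpoint_frequency are written.
saved = [i for i in range(total_log_files) if i % checkpoint_frequency == 0]
assert saved == [0, 3, 6, 9, 12, 15, 18]
# Garbage collection keeps only the newest CHECKPOINT_DURATION checkpoints,
# so 0, 3, and 6 (all <= deleted_log_files) are collected.
assert saved[-CHECKPOINT_DURATION:] == [9, 12, 15, 18]
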
def initialize_checkpointer(checkpoint_dir, checkpoint_file_prefix, agent):
  """Reloads the latest checkpoint if it exists.

  This function will first create a `Checkpointer` object and then call
  `checkpointer.get_latest_checkpoint_number` to determine if there is a
  valid checkpoint in `checkpoint_dir`, and what the largest file number is.
  If a valid checkpoint file is found, it will load the bundled data from
  this file and will pass it to the agent for it to reload its data. If the
  agent is able to successfully unbundle, this function will return the
  iteration number keyed by 'current_iteration', incremented by one, as one
  of the return values (along with the `Checkpointer` object).

  Args:
    checkpoint_dir: str, the directory to search for checkpoints.
    checkpoint_file_prefix: str, the checkpoint file prefix.
    agent: the agent that should reload its data from the checkpoint.

  Returns:
    start_iteration: int, the iteration number to start the experiment from.
    experiment_checkpointer: `Checkpointer` object for the experiment.
  """
  checkpointer_ = checkpointer.Checkpointer(checkpoint_dir,
                                            checkpoint_file_prefix)
  start_iteration = 0
  # Check if checkpoint exists. Note that the existence of checkpoint 0 means
  # that we have finished iteration 0 (so we will start from iteration 1).
  latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
      checkpoint_dir)
  if latest_checkpoint_version >= 0:
    experiment_data = checkpointer_.load_checkpoint(latest_checkpoint_version)
    if agent.unbundle(checkpoint_dir, latest_checkpoint_version,
                      experiment_data):
      assert 'current_iteration' in experiment_data
      start_iteration = experiment_data['current_iteration'] + 1
  return start_iteration, checkpointer_

def testCheckpointingInitialization(self):
  # Fails with an empty directory.
  with self.assertRaisesRegexp(ValueError,
                               'No path provided to Checkpointer.'):
    checkpointer.Checkpointer('')
  # Fails with an invalid directory.
  invalid_dir = '/does/not/exist'
  with self.assertRaisesRegexp(
      ValueError,
      'Unable to create checkpoint path: {}.'.format(invalid_dir)):
    checkpointer.Checkpointer(invalid_dir)
  # Succeeds with a valid directory.
  checkpointer.Checkpointer('/tmp/dopamine_tests')
  # This verifies initialization still works after the directory has already
  # been created.
  self.assertTrue(tf.gfile.Exists('/tmp/dopamine_tests'))
  checkpointer.Checkpointer('/tmp/dopamine_tests')

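# Callers that want to degrade gracefully rather than crash can catch the
# same ValueError the test above asserts on. A minimal sketch: the helper
# name, the fallback-to-None behavior, and the import path are assumptions
# for illustration, not prescribed by this code.
import tensorflow as tf
from dopamine.common import checkpointer  # import path assumed


def make_checkpointer_or_none(base_dir, prefix='ckpt'):
  """Returns a Checkpointer, or None if the path is empty or unusable."""
  try:
    return checkpointer.Checkpointer(base_dir, checkpoint_file_prefix=prefix)
  except ValueError as err:
    # Raised for an empty path or a path that cannot be created, per
    # testCheckpointingInitialization above.
    tf.logging.warning('Checkpointing disabled: %s', err)
    return None
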
def testLoadLatestCheckpoint(self):
  exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
  first_iter = 1729
  exp_checkpointer.save_checkpoint(first_iter, first_iter)
  second_iter = first_iter + 1
  exp_checkpointer.save_checkpoint(second_iter, second_iter)
  self.assertEqual(
      second_iter,
      checkpointer.get_latest_checkpoint_number(self._test_subdir))

def testLogToFileWithValidDirectoryDefaultPrefix(self):
  exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  iteration_number = 1729
  exp_checkpointer.save_checkpoint(iteration_number, data)
  loaded_data = exp_checkpointer.load_checkpoint(iteration_number)
  self.assertEqual(data, loaded_data)
  self.assertIsNone(exp_checkpointer.load_checkpoint(iteration_number + 1))

def testLogToFileWithValidDirectoryCustomPrefix(self):
  prefix = 'custom_prefix'
  exp_checkpointer = checkpointer.Checkpointer(
      self._test_subdir, checkpoint_file_prefix=prefix)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  iteration_number = 1729
  exp_checkpointer.save_checkpoint(iteration_number, data)
  loaded_data = exp_checkpointer.load_checkpoint(iteration_number)
  self.assertEqual(data, loaded_data)
  self.assertIsNone(exp_checkpointer.load_checkpoint(iteration_number + 1))

def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix): """Reloads the latest checkpoint if it exists. This method will first create a `Checkpointer` object and then call `checkpointer.get_latest_checkpoint_number` to determine if there is a valid checkpoint in self._checkpoint_dir, and what the largest file number is. If a valid checkpoint file is found, it will load the bundled data from this file and will pass it to the agent for it to reload its data. If the agent is able to successfully unbundle, this method will verify that the unbundled data contains the keys,'logs' and 'current_iteration'. It will then load the `Logger`'s data from the bundle, and will return the iteration number keyed by 'current_iteration' as one of the return values (along with the `Checkpointer` object). Args: checkpoint_file_prefix: str, the checkpoint file prefix. Returns: start_iteration: int, the iteration number to start the experiment from. experiment_checkpointer: `Checkpointer` object for the experiment. """ self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir, checkpoint_file_prefix) self._start_iteration = 0 # Check if checkpoint exists. Note that the existence of checkpoint 0 means # that we have finished iteration 0 (so we will start from iteration 1). latest_checkpoint_version = checkpointer.get_latest_checkpoint_number( self._checkpoint_dir) if latest_checkpoint_version >= 0: experiment_data = self._checkpointer.load_checkpoint( latest_checkpoint_version) if self._agent.unbundle(self._checkpoint_dir, latest_checkpoint_version, experiment_data): if experiment_data is not None: assert 'logs' in experiment_data assert 'current_iteration' in experiment_data self._logger.data = experiment_data['logs'] self._start_iteration = experiment_data[ 'current_iteration'] + 1 if self._environment.game_name[0:4] == 'VGDL': self._environment.set_level( experiment_data['vgdl_level'], experiment_data['training_steps']) tf.logging.info( 'Reloaded checkpoint and will start from iteration %d', self._start_iteration)
def testGarbageCollection(self):
  custom_prefix = 'custom_prefix'
  exp_checkpointer = checkpointer.Checkpointer(
      self._test_subdir, checkpoint_file_prefix=custom_prefix)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  deleted_log_files = 7
  total_log_files = checkpointer.CHECKPOINT_DURATION + deleted_log_files
  for iteration_number in range(total_log_files):
    exp_checkpointer.save_checkpoint(iteration_number, data)
  for iteration_number in range(total_log_files):
    prefixes = [custom_prefix, 'sentinel_checkpoint_complete']
    for prefix in prefixes:
      checkpoint_file = os.path.join(self._test_subdir,
                                     '{}.{}'.format(prefix, iteration_number))
      if iteration_number < deleted_log_files:
        self.assertFalse(tf.gfile.Exists(checkpoint_file))
      else:
        self.assertTrue(tf.gfile.Exists(checkpoint_file))

def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix): """Reloads the latest checkpoint if it exists. This method will first create a `Checkpointer` object and then call `checkpointer.get_latest_checkpoint_number` to determine if there is a valid checkpoint in self._checkpoint_dir, and what the largest file number is. If a valid checkpoint file is found, it will load the bundled data from this file and will pass it to the agent for it to reload its data. If the agent is able to successfully unbundle, this method will increase and return the iteration number keyed by 'current_iteration' and the step number keyed by 'total_steps' as the return values. Args: checkpoint_file_prefix: str, the checkpoint file prefix. Returns: start_iteration: The iteration number to be continued after the latest checkpoint. start_step: The step number to be continued after the latest checkpoint. """ self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir, checkpoint_file_prefix) start_iteration = 0 start_step = 0 # Check if checkpoint exists. # Note that the existence of checkpoint 0 means that we have finished # iteration 0 (so we will start from iteration 1). latest_checkpoint_version = checkpointer.get_latest_checkpoint_number( self._checkpoint_dir) if latest_checkpoint_version >= 0: assert not self._episode_writer, 'Can only log episodes from scratch.' experiment_data = self._checkpointer.load_checkpoint( latest_checkpoint_version) start_iteration = experiment_data['current_iteration'] + 1 del experiment_data['current_iteration'] start_step = experiment_data['total_steps'] + 1 del experiment_data['total_steps'] if self._agent.unbundle(self._checkpoint_dir, latest_checkpoint_version, experiment_data): tf.logging.info( 'Reloaded checkpoint and will start from ' 'iteration %d', start_iteration) return start_iteration, start_step
def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds(
    self, mock_get_latest):
  # mock_get_latest is injected by a mock.patch decorator on this test
  # (the decorator itself is not shown in this snippet).
  latest_checkpoint = 7
  mock_get_latest.return_value = latest_checkpoint
  logs_data = {'a': 1, 'b': 2}
  current_iteration = 1729
  checkpoint_data = {'current_iteration': current_iteration,
                     'logs': logs_data}
  checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
  checkpoint = checkpointer.Checkpointer(checkpoint_dir, 'ckpt')
  checkpoint.save_checkpoint(latest_checkpoint, checkpoint_data)
  mock_agent = mock.Mock()
  mock_agent.unbundle.return_value = True
  runner = run_experiment.Runner(self._test_subdir,
                                 lambda x, y, summary_writer: mock_agent,
                                 mock.Mock)
  expected_iteration = current_iteration + 1
  self.assertEqual(expected_iteration, runner._start_iteration)
  self.assertDictEqual(logs_data, runner._logger.data)
  mock_agent.unbundle.assert_called_once_with(
      checkpoint_dir, latest_checkpoint, checkpoint_data)

def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
  self._agent.reload_checkpoint(self._trained_agent_checkpoint_path)
  self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                 checkpoint_file_prefix)
  self._start_iteration = 0

def run_training(environment, exploration_fn, random_state, args_dict=None):
  """Executes the training loop.

  Args:
    environment: Instance of Environment.
    exploration_fn: function, this can be a linear decay or a constant
      exploration rate.
    random_state: np.random.RandomState, maintains the random generator
      state.
    args_dict: dictionary, contains all the necessary configuration params.
  """
  last_layer_weights = environment.last_layer_weights()
  last_layer_biases = environment.last_layer_biases()
  last_layer_target_weights = environment.last_layer_target_weights()
  last_layer_target_biases = environment.last_layer_target_biases()
  num_actions = args_dict["num_actions"]
  state_dimensions = args_dict["state_dimensions"]
  checkpoint_dir = os.path.join(FLAGS.checkpoint_path, "checkpoint/")
  checkpoint_handler = checkpointer.Checkpointer(
      checkpoint_dir, checkpoint_frequency=CHECKPOINT_FREQUENCY)
  checkpoint_version = checkpointer.get_latest_checkpoint_number(
      checkpoint_dir)
  if checkpoint_version >= 0:
    checkpoint_data = checkpoint_handler.load_checkpoint(checkpoint_version)
    print(f"Restored checkpoint for iteration {checkpoint_version}.")
    # TODO(dijiasu): Revisit if we need to run agent._sync_qt_ops().
  else:
    print("No checkpoint found. Initializing all variables.")
    # Initialize a CheckpointData object if there is no checkpoint yet.
    target_weight = [last_layer_target_weights, last_layer_target_biases]
    pre_load = [last_layer_weights, last_layer_biases]
    node_queue = queue.Queue()
    if abs(args_dict["consistency_coeff"]) <= EPS_TOLERANCE:
      for _ in range(args_dict["max_num_nodes"]):
        node_queue.put(
            (pre_load, [0], target_weight, target_weight, 0, 0, 0, 0))
    else:
      node_queue.put(
          (pre_load, [0], target_weight, target_weight, 0, 0, 0, 0))
    checkpoint_data = CheckpointData(
        batch_feature_vecs=np.zeros(
            (args_dict["steps_per_iteration"], state_dimensions), dtype="f8"),
        target_weight=[last_layer_target_weights, last_layer_target_biases],
        tree_exp_replay=ConsistencyBuffer(random_state, args_dict),
        pre_load=pre_load,
        node_queue=node_queue,
        tree_level_size=node_queue.qsize(),
        weights_to_be_rollout=pre_load,
        best_sampler=None,
        back_tracking_count=args_dict["back_tracking_count"],
        back_tracking_node_list=args_dict["back_tracking_node_list"],
        start_i=0,
        optimizer=tf.train.RMSPropOptimizer(
            learning_rate=FLAGS.learning_rate,
            decay=0.95,
            epsilon=0.00001,
            centered=True),
        initial_batch_data=None)

  # Check initial performance.
  # TODO(dijiasu): Consider removing.
  # Unpack the checkpoint state before the initial evaluation, so the
  # restored-from-checkpoint path also has pre_load et al. defined.
  batch_feature_vecs = checkpoint_data.batch_feature_vecs
  target_weight = checkpoint_data.target_weight
  tree_exp_replay = checkpoint_data.tree_exp_replay
  pre_load = checkpoint_data.pre_load
  q = checkpoint_data.node_queue
  tree_level_size = checkpoint_data.tree_level_size
  weights_to_be_rollout = checkpoint_data.weights_to_be_rollout
  best_sampler = checkpoint_data.best_sampler
  back_tracking_count = checkpoint_data.back_tracking_count
  back_tracking_node_list = checkpoint_data.back_tracking_node_list
  start_i = checkpoint_data.start_i
  optimizer = checkpoint_data.optimizer
  initial_batch_data = checkpoint_data.initial_batch_data

  rollout_layer = create_single_layer_from_weights(num_actions,
                                                   state_dimensions, pre_load)
  (avg_actual_return, avg_predicted_return,
   avg_q_val) = environment.evaluate_policy(
       random_state, rollout_layer, epsilon_eval=args_dict["eval_eps"])
  print("initial batch #%d true_val: %.2f predicted_val: %.2f\n" %
        (0, avg_actual_return, avg_predicted_return))
  for iteration_children in range(args_dict["max_num_nodes"]):
    with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
        args_dict["num_children"], args_dict["max_num_nodes"],
        args_dict["learning_rate"], args_dict["sample_consis_buffer"])):
      with tf.name_scope("children_cnt={}".format(iteration_children)):
        tf.compat.v2.summary.scalar("actual_return", avg_actual_return,
                                    step=0)
        tf.compat.v2.summary.scalar("predic_return", avg_predicted_return,
                                    step=0)
  with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
      args_dict["num_children"], args_dict["max_num_nodes"],
      args_dict["learning_rate"], args_dict["sample_consis_buffer"])):
    with tf.name_scope("Best!"):
      tf.compat.v2.summary.scalar("best_actual_return", avg_actual_return,
                                  step=0)
      tf.compat.v2.summary.scalar("indice", 0, step=0)

  for i in range(start_i, FLAGS.training_iterations):
    print(f"Starting iteration {i}.")
    level_weights = []
    target_layer_list = []
    for level in range(tree_level_size):
      level_q = q.get()
      (parent_weight, parent_index, target_weight, old_target_weight,
       _, _, _, _) = level_q
      old_target = create_single_layer_from_weights(
          num_actions, state_dimensions, old_target_weight)
      target_layer_list.append(old_target)
      q.put(level_q)
    # Launch the agent and sample one batch of data from the environment.
    if best_sampler is not None:
      single_batch = exploration_fn(
          random_state,
          create_single_layer_from_weights(num_actions, state_dimensions,
                                           best_sampler),
          target_layer_list)
    else:
      single_batch = exploration_fn(
          random_state,
          create_single_layer_from_weights(num_actions, state_dimensions,
                                           pre_load),
          target_layer_list)
    if initial_batch_data is None:
      initial_batch_data = process_initial_batch(single_batch, args_dict)
    for level in range(tree_level_size):
      (parent_weight, parent_index, target_weight, old_target_weight,
       _, _, _, parent_score) = q.get()
      # We update the target weights.
      if i > 0 and i % args_dict["target_replacement_freq"] == 0:
        old_target_weight = target_weight
      # We are doing a Q-update here, splitting the parent node into
      # multiple child nodes.
      children_weights = update_single_step(
          batch_feature_vecs, random_state, parent_weight, tree_level_size,
          tree_exp_replay, parent_index, old_target_weight, i, single_batch,
          level, parent_score, optimizer, args_dict)
      for children_w in children_weights:
        q.put(children_w)
        level_weights.append(children_w)
    # Advance the experience buffer.
    tree_exp_replay.next_level()
    # Delete the previous level's experience.
    tree_exp_replay.forget()
    tree_level_size = q.qsize()
    if i % args_dict["rollout_freq"] == 0:
      children_cnt = 0
      actual_return_ls = []
      children_score = []
      children_loss_bellman = []
      children_loss_regularize = []
      children_loss_total = []
      parent_loss_vector = []
      children_avg_q = []
      num_best_nodes_for_expansion = args_dict["num_best_nodes_for_expansion"]
      for children_w in level_weights:
        with tf.name_scope("children={}_queue={}_lr={}_samp={}".format(
            args_dict["num_children"], args_dict["max_num_nodes"],
            args_dict["learning_rate"], args_dict["sample_consis_buffer"])):
          with tf.name_scope("children_cnt={}".format(children_cnt)):
            (chid_weight_np, _, _, _, total_bellman, total_regularized_loss,
             total_q, parent_score) = children_w
            weights_to_be_rollout = chid_weight_np
            children_loss_bellman.append(total_bellman)
            children_loss_regularize.append(total_regularized_loss)
            children_loss_total.append(total_bellman + total_regularized_loss)
            children_avg_q.append(total_q -
                                  (total_bellman + total_regularized_loss))
            parent_loss_vector.append(parent_score)
            rollout_layer = create_single_layer_from_weights(
                num_actions, state_dimensions, weights_to_be_rollout)
            # Evaluate the policy.
            (avg_actual_return, avg_predicted_return,
             avg_q_val) = environment.evaluate_policy(
                 random_state, rollout_layer,
                 epsilon_eval=args_dict["eval_eps"])
            tf.compat.v2.summary.scalar("actual_return", avg_actual_return,
                                        step=i + 1)
            tf.compat.v2.summary.scalar("predic_return", avg_predicted_return,
                                        step=i + 1)
            tf.compat.v2.summary.scalar("avg_q", avg_q_val, step=i + 1)
            actual_return_ls.append(avg_actual_return)
            children_cnt += 1
            tf.logging.info("batch #%d true_val: %.2f predicted_val: %.2f\n",
                            i, avg_actual_return, avg_predicted_return)
            children_score.append(avg_actual_return)
      children_loss_bellman = np.array(children_loss_bellman)
      children_loss_regularize = np.array(children_loss_regularize)
      children_loss_total = np.array(children_loss_total)
      parent_loss_vector = np.array(parent_loss_vector)
      children_avg_q = np.array(children_avg_q)
      children_score = np.array(children_score)
      children_score_idx = children_score.argsort(
          )[-num_best_nodes_for_expansion:][::-1]
      chosen_children = []
      # Choose the scoring function for selecting which nodes to expand.
      if args_dict["node_scoring_fn"] == "rollouts":
        # This uses the rollouts as the scoring function.
        if (i > 0) and (args_dict["consistency_coeff"] > 0) and args_dict[
            "enable_back_tracking"]:
          for track_cell in back_tracking_node_list:
            # np.append returns a new array; the result must be reassigned.
            children_score = np.append(children_score, track_cell[-1])
          children_score_idx = children_score.argsort(
              )[-num_best_nodes_for_expansion:][::-1]
        chosen_children = children_score.argsort()[::-1]
      elif args_dict["node_scoring_fn"] == "bellman_consistency":
        # Using the Bellman consistency as the scoring function.
        if (i > 0) and (args_dict["consistency_coeff"] > 0) and args_dict[
            "enable_back_tracking"]:
          mean_loss_total_delta = np.mean(children_loss_total -
                                          parent_loss_vector)
          for track_cell in range(len(back_tracking_node_list)):
            back_tracking_node_list[track_cell][-1] = (
                back_tracking_node_list[track_cell][-1] +
                mean_loss_total_delta * BACK_TRACK_CALIBRATION)
            children_loss_total = np.append(
                children_loss_total, back_tracking_node_list[track_cell][-1])
        children_score_idx = children_loss_total.argsort(
            )[:num_best_nodes_for_expansion]
        chosen_children = children_loss_total.argsort()
        try:
          if children_score_idx[0] >= args_dict["max_num_nodes"]:
            best_sampler = back_tracking_node_list[
                children_score_idx[0] - args_dict["max_num_nodes"]][0]
        except IndexError as e:
          tf.logging.error(e)
      if i > 0 and args_dict["consistency_coeff"] > 0:
        for queue_iteration in range(len(level_weights)):
          q.get()
      tf.logging.info("Pruning nodes...")
args_dict["enable_back_tracking"]: for queue_iteration in range(len(level_weights)): level_weights[queue_iteration] = list( level_weights[queue_iteration]) if args_dict["node_scoring_fn"] == "rollouts": level_weights[queue_iteration][ -1] = children_score[queue_iteration] elif args_dict[ "node_scoring_fn"] == "bellman_consistency": level_weights[queue_iteration][ -1] = level_weights[queue_iteration][ -4] + level_weights[queue_iteration][-3] level_weights[queue_iteration] = tuple( level_weights[queue_iteration]) # Queue_iteration is the indices that perform the best according to # Chosen scoring function for queue_iteration in children_score_idx: if args_dict["enable_back_tracking"]: if queue_iteration >= args_dict["max_num_nodes"]: adjust_idx = queue_iteration - args_dict[ "max_num_nodes"] q.put(tuple(back_tracking_node_list[adjust_idx])) back_tracking_count += 1 else: q.put(level_weights[queue_iteration]) else: q.put(level_weights[queue_iteration]) # Pick the highest scoring nodes if args_dict["enable_back_tracking"]: tmp_back_tracking_node_list = [] for idx__ in chosen_children[ num_best_nodes_for_expansion:]: if idx__ >= args_dict["max_num_nodes"]: adjust_idx = idx__ - args_dict["max_num_nodes"] tmp_back_tracking_node_list.append( back_tracking_node_list[adjust_idx]) else: tmp_back_tracking_node_list.append( list(level_weights[idx__])) if len(tmp_back_tracking_node_list ) >= args_dict["back_track_size"]: break back_tracking_node_list = tmp_back_tracking_node_list tree_level_size = num_best_nodes_for_expansion with tf.name_scope("children={}_queue={}_lr={}_samp={}".format( args_dict["num_children"], args_dict["max_num_nodes"], args_dict["learning_rate"], args_dict["sample_consis_buffer"])): with tf.name_scope("Best!"): tf.compat.v2.summary.scalar("best_actual_return", np.max(actual_return_ls), step=i + 1) # TODO(dijiasu): Check if renaming from indice to index breaks this. tf.compat.v2.summary.scalar("indice", np.argmax(actual_return_ls), step=i + 1) if i == 0: tf.logging.info( "Copying the online network weights to target network.") # TODO(dijiasu): Revisit if need to run agent._sync_qt_ops(). # TODO(dijiasu): Keep checkpoint_data variable to avoid re-creating it. checkpoint_data = CheckpointData( batch_feature_vecs=batch_feature_vecs, target_weight=target_weight, tree_exp_replay=tree_exp_replay, pre_load=pre_load, q=q, tree_level_size=tree_level_size, weights_to_be_rollout=weights_to_be_rollout, best_sampler=best_sampler, back_tracking_count=back_tracking_count, back_tracking_node_list=back_tracking_node_list, start_i=i + 1, optimizer=optimizer, initial_batch_data=initial_batch_data, ) checkpoint_handler.save_checkpoint(i, checkpoint_data) args_dict["back_tracking_count"] = back_tracking_count args_dict["back_tracking_node_list"] = back_tracking_node_list