def log_to_screen(self):
    # log to screen
    log = OrderedDict()
    log["Episode"] = self.current_episode
    log["Total reward"] = round(self.total_reward_in_current_episode, 2)
    log["Steps"] = self.total_steps_counter
    screen.log_dict(log, prefix="Recording")
def improve_reward_model(self, epochs: int):
    """
    Train a reward model to be used by the doubly-robust estimator

    :param epochs: The total number of epochs to use for training a reward model
    :return: None
    """
    batch_size = self.ap.network_wrappers['reward_model'].batch_size
    network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

    # this is fitted from the training dataset
    for epoch in range(epochs):
        loss = 0
        for i, batch in enumerate(self.call_memory('get_shuffled_data_generator', batch_size)):
            batch = Batch(batch)
            current_rewards_prediction_for_all_actions = \
                self.networks['reward_model'].online_network.predict(batch.states(network_keys))
            current_rewards_prediction_for_all_actions[range(batch_size), batch.actions()] = batch.rewards()
            loss += self.networks['reward_model'].train_and_sync_networks(
                batch.states(network_keys), current_rewards_prediction_for_all_actions)[0]
            # print(self.networks['reward_model'].online_network.predict(batch.states(network_keys))[0])

        log = OrderedDict()
        log['Epoch'] = epoch
        log['loss'] = loss / int(self.call_memory('num_transitions_in_complete_episodes') / batch_size)
        screen.log_dict(log, prefix='Training Reward Model')
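# --- Illustrative sketch (not part of the framework) ---
# A minimal, self-contained numpy example of the regression-target trick used in
# improve_reward_model() above: the reward model predicts a reward per action, and only
# the entries that correspond to the actions actually taken in the batch are overwritten
# with the observed rewards, so the squared error (and hence the gradient) is zero for
# every untaken action. All names below (predictions, actions, rewards) are hypothetical.
import numpy as np

def build_reward_targets(predictions: np.ndarray, actions: np.ndarray, rewards: np.ndarray) -> np.ndarray:
    """Return regression targets that only constrain the taken actions."""
    targets = predictions.copy()
    targets[np.arange(len(actions)), actions] = rewards
    return targets

if __name__ == '__main__':
    preds = np.array([[0.1, 0.5, -0.2],
                      [0.0, 0.3,  0.7]])
    acts = np.array([1, 2])
    rews = np.array([1.0, -1.0])
    print(build_reward_targets(preds, acts, rews))
    # [[ 0.1  1.  -0.2]
    #  [ 0.   0.3 -1. ]]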
def save_checkpoint(self):
    # create current session's checkpoint directory
    if self.task_parameters.checkpoint_save_dir is None:
        self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')

    if not os.path.exists(self.task_parameters.checkpoint_save_dir):
        os.mkdir(self.task_parameters.checkpoint_save_dir)  # create the directory structure

    checkpoint_name = "{}_Step-{}.ckpt".format(
        self.checkpoint_id,
        self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])

    saved_checkpoint_paths = []
    for agent_params in self.agents_params:
        agent_checkpoint_save_dir = os.path.join(self.task_parameters.checkpoint_save_dir, agent_params.name)
        if not os.path.exists(agent_checkpoint_save_dir):
            os.mkdir(agent_checkpoint_save_dir)

        if self.checkpoint_state_updater[agent_params.name] is None:
            self.checkpoint_state_updater[agent_params.name] = CheckpointStateUpdater(agent_checkpoint_save_dir)

        agent_checkpoint_path = os.path.join(agent_checkpoint_save_dir, checkpoint_name)
        if not isinstance(self.task_parameters, DistributedTaskParameters):
            saved_checkpoint_paths.append(
                self.checkpoint_saver[agent_params.name].save(self.sess[agent_params.name], agent_checkpoint_path))
        else:
            saved_checkpoint_paths.append(agent_checkpoint_path)

        if self.num_checkpoints_to_keep < len(self.checkpoint_state_updater[agent_params.name].all_checkpoints):
            checkpoint_to_delete = \
                self.checkpoint_state_updater[agent_params.name].all_checkpoints[-self.num_checkpoints_to_keep - 1]
            agent_checkpoint_to_delete = os.path.join(agent_checkpoint_save_dir, checkpoint_to_delete.name)
            for file in glob.glob("{}*".format(agent_checkpoint_to_delete)):
                os.remove(file)

    # this is required in order for agents to save additional information, like a DND for example
    [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]

    # purge the Redis memory after saving the checkpoint, as the transitions are no longer needed at this point
    if hasattr(self, 'memory_backend'):
        self.memory_backend.memory_purge()

    # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
    if self.task_parameters.export_onnx_graph:
        self.save_onnx_graph()

    # write the new checkpoint name to a file to signal that this checkpoint has been fully saved
    for agent_params in self.agents_params:
        self.checkpoint_state_updater[agent_params.name].update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))

    screen.log_dict(
        OrderedDict([
            ("Saving in path", saved_checkpoint_path) for saved_checkpoint_path in saved_checkpoint_paths
        ]),
        prefix="Checkpoint"
    )

    self.checkpoint_id += 1
    self.last_checkpoint_saving_time = time.time()

    if hasattr(self, 'data_store_params'):
        data_store = self.get_data_store(self.data_store_params)
        data_store.save_to_store()
def train_rnd(self):
    if self.memory.num_transitions() == 0:
        return

    transitions = self.memory.transitions[-self.ap.algorithm.rnd_sample_size:]
    dataset = Batch(transitions)
    dataset_order = list(range(dataset.size))
    batch_size = self.ap.algorithm.rnd_batch_size

    for epoch in range(self.ap.algorithm.rnd_optimization_epochs):
        shuffle(dataset_order)
        total_loss = 0
        total_grads = 0
        for i in range(int(dataset.size / batch_size)):
            start = i * batch_size
            end = (i + 1) * batch_size

            batch = Batch(list(np.array(dataset.transitions)[dataset_order[start:end]]))
            inputs = self.prepare_rnd_inputs(batch)

            const_embedding = self.networks['constant'].online_network.predict(inputs)

            res = self.networks['predictor'].train_and_sync_networks(inputs, [const_embedding])
            total_loss += res[0]
            total_grads += res[2]

        screen.log_dict(
            OrderedDict([
                ("training epoch", epoch),
                ("dataset size", dataset.size),
                ("mean loss", total_loss / dataset.size),
                ("mean gradients", total_grads / dataset.size)
            ]),
            prefix="RND Training"
        )
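# --- Illustrative sketch (not part of the framework) ---
# train_rnd() above fits the 'predictor' network to the output of the fixed, randomly
# initialized 'constant' network, which is the core of Random Network Distillation (RND):
# states the predictor has seen often are easy to predict, so the prediction error can be
# used as a novelty / intrinsic-reward signal. The tiny numpy version below uses
# hypothetical linear "networks" only to show the idea, not the framework's implementation.
import numpy as np

rng = np.random.default_rng(0)
obs_dim, emb_dim = 8, 4
W_target = rng.normal(size=(obs_dim, emb_dim))   # fixed random target ("constant") network
W_pred = np.zeros((obs_dim, emb_dim))            # trainable predictor network

def rnd_intrinsic_reward(states: np.ndarray) -> np.ndarray:
    # per-state squared prediction error
    return np.sum((states @ W_pred - states @ W_target) ** 2, axis=-1)

def rnd_train_step(states: np.ndarray, lr: float = 1e-2) -> None:
    global W_pred
    err = states @ W_pred - states @ W_target    # (batch, emb_dim)
    grad = states.T @ err / len(states)          # gradient of the mean squared error
    W_pred -= lr * grad

if __name__ == '__main__':
    states = rng.normal(size=(32, obs_dim))
    before = rnd_intrinsic_reward(states).mean()
    for _ in range(200):
        rnd_train_step(states)
    after = rnd_intrinsic_reward(states).mean()
    print(before, '->', after)  # the error (novelty) shrinks for frequently seen states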
def save_checkpoint(self):
    # only the chief process saves checkpoints
    if self.task_parameters.save_checkpoint_secs \
            and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
            and (self.task_parameters.task_index == 0  # distributed
                 or self.task_parameters.task_index is None  # single-worker
                 ):
        checkpoint_path = os.path.join(
            self.task_parameters.save_checkpoint_dir,
            "{}_Step-{}.ckpt".format(
                self.checkpoint_id,
                self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
        if not isinstance(self.task_parameters, DistributedTaskParameters):
            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
        else:
            saved_checkpoint_path = checkpoint_path

        # this is required in order for agents to save additional information, like a DND for example
        [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]

        screen.log_dict(
            OrderedDict([
                ("Saving in path", saved_checkpoint_path),
            ]),
            prefix="Checkpoint"
        )

        self.checkpoint_id += 1
        self.last_checkpoint_saving_time = time.time()
def log_to_screen(self): # log to screen log = OrderedDict() log["Episode"] = self.episode_idx log["Total reward"] = np.round(self.total_reward_in_current_episode, 2) log["Steps"] = self.total_steps_counter screen.log_dict(log, prefix=self.phase.value)
def log_to_screen(self): # log to screen log = OrderedDict() log["Name"] = self.full_name_id if self.task_id is not None: log["Worker"] = self.task_id log["Episode"] = self.current_episode log["Total reward"] = round(self.total_reward_in_current_episode, 2) log["Steps"] = self.total_steps_counter log["Training iteration"] = self.training_iteration screen.log_dict(log, prefix=self.phase.value)
def save_checkpoint(self):
    # create current session's checkpoint directory
    if self.task_parameters.checkpoint_save_dir is None:
        self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')

    if not os.path.exists(self.task_parameters.checkpoint_save_dir):
        os.mkdir(self.task_parameters.checkpoint_save_dir)  # create the directory structure

    if self.checkpoint_state_updater is None:
        self.checkpoint_state_updater = CheckpointStateUpdater(self.task_parameters.checkpoint_save_dir)

    checkpoint_name = "{}_Step-{}.ckpt".format(
        self.checkpoint_id,
        self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
    checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, checkpoint_name)

    if not isinstance(self.task_parameters, DistributedTaskParameters):
        saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
    else:
        saved_checkpoint_path = checkpoint_path

    # this is required in order for agents to save additional information, like a DND for example
    [manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]

    # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
    if self.task_parameters.export_onnx_graph:
        self.save_onnx_graph()

    # write the new checkpoint name to a file to signal that this checkpoint has been fully saved
    self.checkpoint_state_updater.update(SingleCheckpoint(self.checkpoint_id, checkpoint_name))

    screen.log_dict(
        OrderedDict([
            ("Saving in path", saved_checkpoint_path),
        ]),
        prefix="Checkpoint"
    )

    self.checkpoint_id += 1
    self.last_checkpoint_saving_time = time.time()

    if hasattr(self, 'data_store_params'):
        data_store = self.get_data_store(self.data_store_params)
        data_store.save_to_store()
def log_to_screen(self):
    # log to screen
    if self.phase == RunPhase.TRAIN:
        # for the training phase - we log during the episode to visualize the progress in training
        log = OrderedDict()
        if self.task_id is not None:
            log["Worker"] = self.task_id
        log["Episode"] = self.current_episode
        log["Loss"] = self.loss.values[-1]
        log["Training iteration"] = self.training_iteration
        screen.log_dict(log, prefix="Training")
    else:
        # for the evaluation phase - logging as in regular RL
        super().log_to_screen()
def run_off_policy_evaluation(self):
    """
    Run the off-policy evaluation estimators to get a prediction for the performance of the current policy,
    based on an evaluation dataset that was collected by one or more other policies.

    :return: None
    """
    assert self.ope_manager

    if not isinstance(self.pre_network_filter, NoInputFilter) and len(self.pre_network_filter.reward_filters) != 0:
        raise ValueError("Defining a pre-network reward filter when OPEs are calculated will result in a mismatch "
                         "between q values (which are scaled) and actual rewards (which are not). It is advisable "
                         "to use an input_filter instead, if possible, which filters the transitions directly "
                         "in the replay buffer, affecting both the q_values and the rewards themselves.")

    ips, dm, dr, seq_dr, wis = self.ope_manager.evaluate(
        evaluation_dataset_as_episodes=self.memory.evaluation_dataset_as_episodes,
        evaluation_dataset_as_transitions=self.memory.evaluation_dataset_as_transitions,
        batch_size=self.ap.network_wrappers['main'].batch_size,
        discount_factor=self.ap.algorithm.discount,
        q_network=self.networks['main'].online_network,
        network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))

    # print the estimators to the screen
    log = OrderedDict()
    log['Epoch'] = self.training_epoch
    log['IPS'] = ips
    log['DM'] = dm
    log['DR'] = dr
    log['WIS'] = wis
    log['Sequential-DR'] = seq_dr
    screen.log_dict(log, prefix='Off-Policy Evaluation')

    # send the estimators to the dashboard
    self.agent_logger.set_current_time(self.get_current_time() + 1)
    self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
    self.agent_logger.create_signal_value('Direct Method Reward', dm)
    self.agent_logger.create_signal_value('Doubly Robust', dr)
    self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
    self.agent_logger.create_signal_value('Weighted Importance Sampling', wis)
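# --- Illustrative sketch (not part of the framework) ---
# Two of the estimators logged above, Inverse Propensity Scoring (IPS) and Weighted
# Importance Sampling (WIS), can be written in a few lines for the one-step (bandit)
# case: given per-sample rewards, the behavior policy's probabilities of the logged
# actions, and the evaluated policy's probabilities of the same actions, IPS is the
# importance-weighted mean reward and WIS normalizes by the sum of the weights.
# This is only a toy illustration; the ope_manager above works on full evaluation episodes.
import numpy as np

def ips_and_wis(rewards: np.ndarray, behavior_probs: np.ndarray, target_probs: np.ndarray):
    weights = target_probs / behavior_probs              # importance-sampling ratios
    ips = np.mean(weights * rewards)                     # unbiased, but high variance
    wis = np.sum(weights * rewards) / np.sum(weights)    # biased, but lower variance
    return ips, wis

if __name__ == '__main__':
    r = np.array([1.0, 0.0, 1.0, 1.0])
    b = np.array([0.5, 0.5, 0.25, 0.25])
    t = np.array([0.9, 0.1, 0.6, 0.2])
    print(ips_and_wis(r, b, t))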
def log_to_screen(self) -> None:
    """
    Write an episode summary line to the terminal

    :return: None
    """
    # log to screen
    log = OrderedDict()
    log["Name"] = self.full_name_id
    if self.task_id is not None:
        log["Worker"] = self.task_id
    log["Episode"] = self.current_episode
    log["Total reward"] = np.round(self.total_reward_in_current_episode, 2)
    log["Exploration"] = np.round(self.exploration_policy.get_control_param(), 2)
    log["Steps"] = self.total_steps_counter
    log["Training iteration"] = self.training_iteration
    screen.log_dict(log, prefix=self.phase.value)
def run_off_policy_evaluation(self):
    """
    Run the off-policy evaluation estimators to get a prediction for the performance of the current policy,
    based on an evaluation dataset that was collected by one or more other policies.

    :return: None
    """
    assert self.ope_manager
    dataset_as_episodes = self.call_memory(
        'get_all_complete_episodes_from_to',
        (self.call_memory('get_last_training_set_episode_id') + 1,
         self.call_memory('num_complete_episodes')))
    if len(dataset_as_episodes) == 0:
        raise ValueError('train_to_eval_ratio is too high, causing the evaluation set to be empty. '
                         'Consider decreasing its value.')

    ips, dm, dr, seq_dr = self.ope_manager.evaluate(
        dataset_as_episodes=dataset_as_episodes,
        batch_size=self.ap.network_wrappers['main'].batch_size,
        discount_factor=self.ap.algorithm.discount,
        reward_model=self.networks['reward_model'].online_network,
        q_network=self.networks['main'].online_network,
        network_keys=list(self.ap.network_wrappers['main'].input_embedders_parameters.keys()))

    # print the estimators to the screen
    log = OrderedDict()
    log['Epoch'] = self.training_epoch
    log['IPS'] = ips
    log['DM'] = dm
    log['DR'] = dr
    log['Sequential-DR'] = seq_dr
    screen.log_dict(log, prefix='Off-Policy Evaluation')

    # send the estimators to the dashboard
    self.agent_logger.set_current_time(self.get_current_time() + 1)
    self.agent_logger.create_signal_value('Inverse Propensity Score', ips)
    self.agent_logger.create_signal_value('Direct Method Reward', dm)
    self.agent_logger.create_signal_value('Doubly Robust', dr)
    self.agent_logger.create_signal_value('Sequential Doubly Robust', seq_dr)
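# --- Illustrative sketch (not part of the framework) ---
# The DR value logged above is the doubly-robust estimator, which combines the direct
# method's value from the reward model with an importance-weighted correction of its
# error on the logged action: DR = E[ V_hat(s) + rho * (r - r_hat(s, a)) ], where
# V_hat(s) = sum_a pi(a|s) * r_hat(s, a) and rho = pi(a|s) / beta(a|s). Below is a toy
# one-step (bandit) version with hypothetical inputs; the sequential variant applies the
# same correction recursively along each episode.
import numpy as np

def doubly_robust(rewards, actions, behavior_probs, target_policy, reward_model_preds):
    """
    rewards: (n,) observed rewards; actions: (n,) logged action indices
    behavior_probs: (n,) beta(a|s) of the logged actions
    target_policy: (n, num_actions) pi(.|s); reward_model_preds: (n, num_actions) r_hat(s, .)
    """
    n = len(rewards)
    v_hat = np.sum(target_policy * reward_model_preds, axis=1)
    rho = target_policy[np.arange(n), actions] / behavior_probs
    correction = rho * (rewards - reward_model_preds[np.arange(n), actions])
    return np.mean(v_hat + correction)

if __name__ == '__main__':
    print(doubly_robust(rewards=np.array([1.0, 0.0]),
                        actions=np.array([0, 1]),
                        behavior_probs=np.array([0.5, 0.5]),
                        target_policy=np.array([[0.8, 0.2], [0.3, 0.7]]),
                        reward_model_preds=np.array([[0.9, 0.1], [0.2, 0.4]])))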
def save_checkpoint(self):
    if self.task_parameters.checkpoint_save_dir is None:
        self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')

    checkpoint_path = os.path.join(
        self.task_parameters.checkpoint_save_dir,
        "{}_Step-{}.ckpt".format(
            self.checkpoint_id,
            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))

    if not isinstance(self.task_parameters, DistributedTaskParameters):
        if self.checkpoint_saver is not None:
            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
        else:
            saved_checkpoint_path = "<Not Saved>"
    else:
        saved_checkpoint_path = checkpoint_path

    # this is required in order for agents to save additional information, like a DND for example
    [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]

    # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
    if self.task_parameters.export_onnx_graph:
        self.save_onnx_graph()

    screen.log_dict(
        OrderedDict([
            ("Saving in path", saved_checkpoint_path),
        ]),
        prefix="Checkpoint"
    )

    self.checkpoint_id += 1
    self.last_checkpoint_saving_time = time.time()

    if hasattr(self, 'data_store_params'):
        data_store = self.get_data_store(self.data_store_params)
        data_store.save_to_store()
def improve_reward_model(self, epochs: int):
    """
    Train a reward model to be used by the doubly-robust estimator

    :param epochs: The total number of epochs to use for training a reward model
    :return: None
    """
    batch_size = self.ap.network_wrappers['reward_model'].batch_size

    # this is fitted from the training dataset
    for epoch in range(epochs):
        loss = 0
        total_transitions_processed = 0
        for i, batch in enumerate(self.call_memory('get_shuffled_training_data_generator', batch_size)):
            batch = Batch(batch)
            loss += self.get_reward_model_loss(batch)
            total_transitions_processed += batch.size

        log = OrderedDict()
        log['Epoch'] = epoch
        log['loss'] = loss / total_transitions_processed
        screen.log_dict(log, prefix='Training Reward Model')
def restore_checkpoint(self):
    self.verify_graph_was_created()

    # TODO: find a better way to load checkpoints that were saved with a global network into the online network
    if self.task_parameters.checkpoint_restore_path:
        restored_checkpoint_paths = []
        for agent_params in self.agents_params:
            if len(self.agents_params) == 1:
                agent_checkpoint_restore_path = self.task_parameters.checkpoint_restore_path
            else:
                agent_checkpoint_restore_path = os.path.join(self.task_parameters.checkpoint_restore_path,
                                                             agent_params.name)
            if os.path.isdir(agent_checkpoint_restore_path):
                # restoring from a checkpoint directory
                if self.task_parameters.framework_type == Frameworks.tensorflow and \
                        'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                    # TODO-fixme checkpointing
                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                    # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying
                    # to restore the model, as the checkpoint names defined do not match the actual checkpoint names.
                    raise NotImplementedError('Checkpointing not implemented for TF monitored training session')
                else:
                    checkpoint = get_checkpoint_state(agent_checkpoint_restore_path, all_checkpoints=True)
                    if checkpoint is None:
                        raise ValueError("No checkpoint to restore in: {}".format(agent_checkpoint_restore_path))
                    model_checkpoint_path = checkpoint.model_checkpoint_path
                    checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                    restored_checkpoint_paths.append(model_checkpoint_path)

                    # set the last checkpoint ID - only in the case of the path being a dir
                    chkpt_state_reader = CheckpointStateReader(agent_checkpoint_restore_path,
                                                               checkpoint_state_optional=False)
                    self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
            else:
                # restoring from a single checkpoint file
                if self.task_parameters.framework_type == Frameworks.tensorflow:
                    model_checkpoint_path = agent_checkpoint_restore_path
                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
                    restored_checkpoint_paths.append(model_checkpoint_path)
                else:
                    raise ValueError("Currently, restoring a checkpoint using the --checkpoint_restore_file "
                                     "argument is only supported with tensorflow.")

            try:
                self.checkpoint_saver[agent_params.name].restore(self.sess[agent_params.name], model_checkpoint_path)
            except Exception as ex:
                raise ValueError("Failed to restore {}'s checkpoint: {}".format(agent_params.name, ex))

            all_checkpoints = sorted(list(set([c.name for c in checkpoint.all_checkpoints])))  # remove duplicates
            if self.num_checkpoints_to_keep < len(all_checkpoints):
                checkpoint_to_delete = all_checkpoints[-self.num_checkpoints_to_keep - 1]
                agent_checkpoint_to_delete = os.path.join(agent_checkpoint_restore_path, checkpoint_to_delete)
                for file in glob.glob("{}*".format(agent_checkpoint_to_delete)):
                    os.remove(file)

        [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
        [manager.post_training_commands() for manager in self.level_managers]

        screen.log_dict(
            OrderedDict([
                ("Restoring from path", restored_checkpoint_path)
                for restored_checkpoint_path in restored_checkpoint_paths
            ]),
            prefix="Checkpoint"
        )
def train_network(self, batch, epochs):
    batch_results = []
    for j in range(epochs):
        batch.shuffle()
        batch_results = {
            'total_loss': [],
            'losses': [],
            'unclipped_grads': [],
            'kl_divergence': [],
            'entropy': []
        }

        fetches = [self.networks['main'].online_network.output_heads[1].kl_divergence,
                   self.networks['main'].online_network.output_heads[1].entropy,
                   self.networks['main'].online_network.output_heads[1].likelihood_ratio,
                   self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio]

        # TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer,
        #  we do not train on some of the data
        for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
            start = i * self.ap.network_wrappers['main'].batch_size
            end = (i + 1) * self.ap.network_wrappers['main'].batch_size

            network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()
            actions = batch.actions()[start:end]
            gae_based_value_targets = batch.info('gae_based_value_target')[start:end]
            if not isinstance(self.spaces.action, DiscreteActionSpace) and len(actions.shape) == 1:
                actions = np.expand_dims(actions, -1)

            # get old policy probabilities and distribution
            # TODO-perf - the target network ("old_policy") is not changing. this can be calculated once
            #  for all epochs. the shuffling should then only be performed on the indices.
            result = self.networks['main'].target_network.predict(
                {k: v[start:end] for k, v in batch.states(network_keys).items()})
            old_policy_distribution = result[1:]

            total_returns = batch.n_step_discounted_rewards(expand_dims=True)

            # calculate gradients and apply them to both the local policy network and the global policy network
            if self.ap.algorithm.estimate_state_value_using_gae:
                value_targets = np.expand_dims(gae_based_value_targets, -1)
            else:
                value_targets = total_returns[start:end]

            inputs = copy.copy({k: v[start:end] for k, v in batch.states(network_keys).items()})
            inputs['output_1_0'] = actions

            # the old_policy_distribution needs to be represented as a list, because in the event of
            # discrete controls it has just a mean; otherwise, it has both a mean and a standard deviation
            for input_index, input in enumerate(old_policy_distribution):
                inputs['output_1_{}'.format(input_index + 1)] = input

            # update the clipping decay schedule value
            inputs['output_1_{}'.format(len(old_policy_distribution) + 1)] = \
                self.ap.algorithm.clipping_decay_schedule.current_value

            total_loss, losses, unclipped_grads, fetch_result = \
                self.networks['main'].train_and_sync_networks(
                    inputs, [value_targets, batch.info('advantage')[start:end]], additional_fetches=fetches
                )

            batch_results['total_loss'].append(total_loss)
            batch_results['losses'].append(losses)
            batch_results['unclipped_grads'].append(unclipped_grads)
            batch_results['kl_divergence'].append(fetch_result[0])
            batch_results['entropy'].append(fetch_result[1])

            self.unclipped_grads.add_sample(unclipped_grads)
            self.value_targets.add_sample(value_targets)
            self.likelihood_ratio.add_sample(fetch_result[2])
            self.clipped_likelihood_ratio.add_sample(fetch_result[3])

        for key in batch_results.keys():
            batch_results[key] = np.mean(batch_results[key], 0)

        self.value_loss.add_sample(batch_results['losses'][0])
        self.policy_loss.add_sample(batch_results['losses'][1])
        self.loss.add_sample(batch_results['total_loss'])

        if self.ap.network_wrappers['main'].learning_rate_decay_rate != 0:
            curr_learning_rate = self.networks['main'].online_network.get_variable_value(
                self.networks['main'].online_network.adaptive_learning_rate_scheme)
            self.curr_learning_rate.add_sample(curr_learning_rate)
        else:
            curr_learning_rate = self.ap.network_wrappers['main'].learning_rate

        # log training parameters
        screen.log_dict(
            OrderedDict([
                ("Surrogate loss", batch_results['losses'][1]),
                ("KL divergence", batch_results['kl_divergence']),
                ("Entropy", batch_results['entropy']),
                ("training epoch", j),
                ("learning_rate", curr_learning_rate)
            ]),
            prefix="Policy training"
        )

    self.total_kl_divergence_during_training_process = batch_results['kl_divergence']
    self.entropy.add_sample(batch_results['entropy'])
    self.kl_divergence.add_sample(batch_results['kl_divergence'])
    return batch_results['losses']
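# --- Illustrative sketch (not part of the framework) ---
# The surrogate loss reported by train_network() above corresponds to PPO's clipped
# objective: with likelihood ratio r = pi_new(a|s) / pi_old(a|s) and advantage A, the
# policy maximizes E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]. The real computation
# lives inside the network's output head; the numpy helper below is a hypothetical
# stand-alone version, shown only to make the formula concrete.
import numpy as np

def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    ratio = np.exp(new_log_probs - old_log_probs)
    clipped_ratio = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # negative sign: we minimize the loss, i.e. maximize the surrogate objective
    return -np.mean(np.minimum(ratio * advantages, clipped_ratio * advantages))

if __name__ == '__main__':
    new_lp = np.log(np.array([0.6, 0.2, 0.9]))
    old_lp = np.log(np.array([0.5, 0.4, 0.3]))
    adv = np.array([1.0, -0.5, 2.0])
    print(clipped_surrogate_loss(new_lp, old_lp, adv))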
def train_policy_network(self, dataset, epochs):
    loss = []
    for j in range(epochs):
        loss = {
            'total_loss': [],
            'policy_losses': [],
            'unclipped_grads': [],
            'fetch_result': []
        }
        # shuffle(dataset)
        for i in range(len(dataset) // self.ap.network_wrappers['actor'].batch_size):
            batch = Batch(dataset[i * self.ap.network_wrappers['actor'].batch_size:
                                  (i + 1) * self.ap.network_wrappers['actor'].batch_size])

            network_keys = self.ap.network_wrappers['actor'].input_embedders_parameters.keys()

            advantages = batch.info('advantage')
            actions = batch.actions()
            if not isinstance(self.spaces.action, DiscreteActionSpace) and len(actions.shape) == 1:
                actions = np.expand_dims(actions, -1)

            # get old policy probabilities and distribution
            old_policy = force_list(self.networks['actor'].target_network.predict(batch.states(network_keys)))

            # calculate gradients and apply them to both the local policy network and the global policy network
            fetches = [self.networks['actor'].online_network.output_heads[0].kl_divergence,
                       self.networks['actor'].online_network.output_heads[0].entropy]

            inputs = copy.copy(batch.states(network_keys))
            inputs['output_0_0'] = actions

            # old_policy_distribution needs to be represented as a list, because in the event of discrete controls
            # it has just a mean; otherwise, it has both a mean and a standard deviation
            for input_index, input in enumerate(old_policy):
                inputs['output_0_{}'.format(input_index + 1)] = input

            total_loss, policy_losses, unclipped_grads, fetch_result = \
                self.networks['actor'].online_network.accumulate_gradients(
                    inputs, [advantages], additional_fetches=fetches)

            self.networks['actor'].apply_gradients_to_online_network()
            if isinstance(self.ap.task_parameters, DistributedTaskParameters):
                self.networks['actor'].apply_gradients_to_global_network()

            self.networks['actor'].online_network.reset_accumulated_gradients()

            loss['total_loss'].append(total_loss)
            loss['policy_losses'].append(policy_losses)
            loss['unclipped_grads'].append(unclipped_grads)
            loss['fetch_result'].append(fetch_result)

            self.unclipped_grads.add_sample(unclipped_grads)

        for key in loss.keys():
            loss[key] = np.mean(loss[key], 0)

        if self.ap.network_wrappers['critic'].learning_rate_decay_rate != 0:
            curr_learning_rate = self.networks['critic'].online_network.get_variable_value(self.ap.learning_rate)
            self.curr_learning_rate.add_sample(curr_learning_rate)
        else:
            curr_learning_rate = self.ap.network_wrappers['critic'].learning_rate

        # log training parameters
        screen.log_dict(
            OrderedDict([
                ("Surrogate loss", loss['policy_losses'][0]),
                ("KL divergence", loss['fetch_result'][0]),
                ("Entropy", loss['fetch_result'][1]),
                ("training epoch", j),
                ("learning_rate", curr_learning_rate)
            ]),
            prefix="Policy training"
        )

        self.total_kl_divergence_during_training_process = loss['fetch_result'][0]
        self.entropy.add_sample(loss['fetch_result'][1])
        self.kl_divergence.add_sample(loss['fetch_result'][0])

    return loss['total_loss']
def improve_reward_model(self, epochs: int):
    """
    Train both a reward model, to be used by the doubly-robust estimator, and a model used by BCQ
    to decide which actions to drop

    :param epochs: The total number of epochs to use for training a reward model
    :return: None
    """
    # we assume that these are drawn from the reward model parameters
    batch_size = self.ap.network_wrappers['reward_model'].batch_size
    network_keys = self.ap.network_wrappers['reward_model'].input_embedders_parameters.keys()

    # if using a NN to decide which actions to drop, we'll train the NN here
    if isinstance(self.ap.algorithm.action_drop_method_parameters, NNImitationModelParameters):
        total_epochs = max(epochs, self.ap.algorithm.action_drop_method_parameters.imitation_model_num_epochs)
    else:
        total_epochs = epochs

    for epoch in range(total_epochs):
        # this is fitted from the training dataset
        reward_model_loss = 0
        imitation_model_loss = 0
        total_transitions_processed = 0
        for i, batch in enumerate(self.call_memory('get_shuffled_training_data_generator', batch_size)):
            batch = Batch(batch)

            # reward model
            if epoch < epochs:
                reward_model_loss += self.get_reward_model_loss(batch)

            # imitation model
            if isinstance(self.ap.algorithm.action_drop_method_parameters, NNImitationModelParameters) and \
                    epoch < self.ap.algorithm.action_drop_method_parameters.imitation_model_num_epochs:
                target_actions = np.zeros((batch.size, len(self.spaces.action.actions)))
                target_actions[range(batch.size), batch.actions()] = 1
                imitation_model_loss += self.networks['imitation_model'].train_and_sync_networks(
                    batch.states(network_keys), target_actions)[0]

            total_transitions_processed += batch.size

        log = OrderedDict()
        log['Epoch'] = epoch
        if reward_model_loss:
            log['Reward Model Loss'] = reward_model_loss / total_transitions_processed
        if imitation_model_loss:
            log['Imitation Model Loss'] = imitation_model_loss / total_transitions_processed
        screen.log_dict(log, prefix='Training Batch RL Models')

    # if using a kNN-based model, we'll initialize and build it here.
    # the initialization cannot be moved to the constructor, as we don't have the agent's spaces initialized yet.
    if isinstance(self.ap.algorithm.action_drop_method_parameters, KNNParameters):
        knn_size = self.ap.algorithm.action_drop_method_parameters.knn_size
        if self.ap.algorithm.action_drop_method_parameters.use_state_embedding_instead_of_state:
            self.knn_trees = [AnnoyDictionary(
                dict_size=knn_size,
                key_width=int(self.networks['reward_model'].online_network.state_embedding.shape[-1]),
                batch_size=knn_size)
                for _ in range(len(self.spaces.action.actions))]
        else:
            self.knn_trees = [AnnoyDictionary(
                dict_size=knn_size,
                key_width=self.spaces.state['observation'].shape[0],
                batch_size=knn_size)
                for _ in range(len(self.spaces.action.actions))]

        for i, knn_tree in enumerate(self.knn_trees):
            state_embeddings = self.embedding([transition.state for transition in self.memory.transitions
                                               if transition.action == i])
            knn_tree.add(
                keys=state_embeddings,
                values=np.expand_dims(np.zeros(state_embeddings.shape[0]), axis=1))

        for knn_tree in self.knn_trees:
            knn_tree._rebuild_index()

        self.average_dist = [[dist[0] for dist in knn_tree._get_k_nearest_neighbors_indices(
            keys=self.embedding([transition.state for transition in self.memory.transitions]),
            k=1)[0]]
            for knn_tree in self.knn_trees]
        self.average_dist = sum([x for l in self.average_dist for x in l])  # flatten and sum
        self.average_dist /= len(self.memory.transitions)
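# --- Illustrative sketch (not part of the framework) ---
# The kNN structures built above (one tree per action, plus an average nearest-neighbor
# distance over the dataset) support a BCQ-style action filter: an action is considered
# "in distribution" for a state only if the state's distance to its nearest neighbor in
# that action's tree is not much larger than the average distance. The helper below is a
# hypothetical, framework-independent numpy sketch of that idea, not the agent's actual
# filtering code.
import numpy as np

def allowed_actions(state_embedding: np.ndarray,
                    keys_per_action: list,       # list of (n_i, d) arrays, one per action
                    average_dist: float,
                    tolerance: float = 1.0) -> np.ndarray:
    """Boolean mask over actions; True means the action is kept for the argmax."""
    mask = np.zeros(len(keys_per_action), dtype=bool)
    for a, keys in enumerate(keys_per_action):
        if len(keys) == 0:
            continue
        nearest = np.min(np.linalg.norm(keys - state_embedding, axis=1))
        mask[a] = nearest <= tolerance * average_dist
    return mask

if __name__ == '__main__':
    keys = [np.array([[0.0, 0.0], [0.1, 0.1]]), np.array([[5.0, 5.0]])]
    print(allowed_actions(np.array([0.05, 0.0]), keys, average_dist=0.2))  # [ True False]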