def sample_memories(self, batch_size, batch_first=False):
    """
    :param batch_size: number of samples to return
    :param batch_first: If True, the first dimension of data is batch_size.
        If False (default), the first dimension is SEQ_LEN. Therefore,
        state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
        By default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
    """
    sample_indices = np.random.randint(self.memory_size, size=batch_size)
    # state/next state shape: batch_size x seq_len x state_dim
    # action shape: batch_size x seq_len x action_dim
    # reward/not_terminal shape: batch_size x seq_len
    state, action, next_state, reward, not_terminal = map(
        lambda x: torch.tensor(x, dtype=torch.float),
        zip(*self.deque_sample(sample_indices)),
    )
    if not batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal
        )
    training_input = rlt.MemoryNetworkInput(
        state=rlt.FeatureVector(float_features=state),
        action=rlt.FeatureVector(float_features=action),
        next_state=next_state,
        reward=reward,
        not_terminal=not_terminal,
    )
    return rlt.TrainingBatch(training_input=training_input, extras=None)
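# A minimal sketch (not part of the class above) of what the `transpose`
# helper is assumed to do here: swap the batch and sequence dimensions of
# each tensor so that SEQ_LEN comes first, as the MDN-RNN expects. The
# helper name and shapes below are illustrative only.
import torch


def _transpose_batch_seq(*tensors):
    # (BATCH_SIZE, SEQ_LEN, ...) -> (SEQ_LEN, BATCH_SIZE, ...)
    return tuple(t.transpose(0, 1) for t in tensors)


batch_size, seq_len, state_dim = 4, 3, 8
state = torch.randn(batch_size, seq_len, state_dim)
(seq_first_state,) = _transpose_batch_seq(state)
assert seq_first_state.shape == (seq_len, batch_size, state_dim)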
def embed_state(self, state):
    """ Embed state after either reset() or step() """
    assert len(self.recent_states) == len(self.recent_actions)
    old_mdnrnn_mode = self.mdnrnn.mdnrnn.training
    self.mdnrnn.mdnrnn.eval()

    # Embed the state as the hidden layer's output
    # until the previous step + current state
    if len(self.recent_states) == 0:
        mdnrnn_state = np.zeros((1, self.raw_state_dim))
        mdnrnn_action = np.zeros((1, self.action_dim))
    else:
        mdnrnn_state = np.array(list(self.recent_states))
        mdnrnn_action = np.array(list(self.recent_actions))

    mdnrnn_state = torch.tensor(mdnrnn_state, dtype=torch.float).unsqueeze(1)
    mdnrnn_action = torch.tensor(mdnrnn_action, dtype=torch.float).unsqueeze(1)
    mdnrnn_input = rlt.StateAction(
        state=rlt.FeatureVector(float_features=mdnrnn_state),
        action=rlt.FeatureVector(float_features=mdnrnn_action),
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    hidden_embed = (
        mdnrnn_output.all_steps_lstm_hidden[-1].squeeze().detach().numpy()
    )
    state_embed = np.hstack((hidden_embed, state))
    self.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
    logger.debug(
        "Embed_state\nrecent states: {}\nrecent actions: {}\nstate_embed: {}\n".format(
            np.array(self.recent_states), np.array(self.recent_actions), state_embed
        )
    )
    return state_embed
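# A minimal sketch (with made-up dimensions) of the embedding layout produced
# by embed_state above: the MDN-RNN's last hidden vector is concatenated with
# the raw state, so the embedded state has hidden_dim + raw_state_dim entries.
import numpy as np

hidden_dim, raw_state_dim = 16, 4
hidden_embed = np.zeros(hidden_dim)
raw_state = np.ones(raw_state_dim)
state_embed = np.hstack((hidden_embed, raw_state))
assert state_embed.shape == (hidden_dim + raw_state_dim,)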
def input_prototype(self):
    return rlt.StateAction(
        state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim)),
        action=rlt.FeatureVector(float_features=torch.randn(1, self.action_dim)),
    )
def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    action = fetch_action(extract_record.action)
    reward = fetch(input_record.reward).reshape(-1, 1)
    # is_terminal should be filled by preprocessor
    if self.max_q_learning:
        if self.sorted_action_features is not None:
            next_state = None
            tiled_next_state = mt.FeatureVector(
                float_features=fetch(extract_record.tiled_next_state)
            )
        else:
            next_state = mt.FeatureVector(
                float_features=fetch(extract_record.next_state)
            )
            tiled_next_state = None
        possible_next_actions = mt.PossibleActions(
            lengths=fetch(extract_record.possible_next_actions["lengths"]),
            actions=fetch_action(extract_record.possible_next_actions["values"]),
        )
        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            tiled_next_state=tiled_next_state,
            possible_next_actions=possible_next_actions,
            reward=reward,
            not_terminal=(possible_next_actions.lengths > 0).float().reshape(-1, 1),
        )
    else:
        next_state = mt.FeatureVector(
            float_features=fetch(extract_record.next_state)
        )
        next_action = fetch_action(extract_record.next_action)
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            # HACK: Need a better way to check this
            not_terminal=torch.ones_like(reward),
        )
    # TODO: stuff other fields in here
    extras = mt.ExtraData(
        action_probability=fetch(input_record.action_probability).reshape(-1, 1)
    )
    return mt.TrainingBatch(training_input=training_input, extras=extras)
def as_discrete_maxq_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.MaxQLearningInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=self.next_actions,
            tiled_next_state=None,
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=None,
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def input_prototype(self):
    if self.parametric_action:
        return rlt.StateAction(
            state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim)),
            action=rlt.FeatureVector(float_features=torch.randn(1, self.action_dim)),
        )
    else:
        return rlt.StateInput(
            state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim))
        )
def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    if self.sorted_action_features is None:
        action = None
    else:
        action = mt.FeatureVector(float_features=fetch(extract_record.action))
    return mt.StateAction(state=state, action=action)
def as_parametric_sarsa_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.SARSAInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=rlt.FeatureVector(float_features=self.actions),
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=rlt.FeatureVector(float_features=self.next_actions),
            reward=self.rewards,
            not_terminal=self.not_terminals,
        ),
        extras=rlt.ExtraData(),
    )
def internal_reward_estimation(self, state, action):
    """ Only used by Gym """
    self.reward_network.eval()
    reward_estimates = self.reward_network(
        rlt.StateAction(
            state=rlt.FeatureVector(float_features=state),
            action=rlt.FeatureVector(float_features=action),
        )
    )
    self.reward_network.train()
    return reward_estimates.q_value.cpu()
def as_discrete_sarsa_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.SARSAInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=self.next_actions,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
def as_policy_network_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.PolicyNetworkInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=rlt.FeatureVector(float_features=self.actions),
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=rlt.FeatureVector(float_features=self.next_actions),
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
def internal_reward_estimation(self, state, action):
    """ Only used by Gym """
    self.reward_network.eval()
    with torch.no_grad():
        state = torch.from_numpy(np.array(state)).type(self.dtype)
        action = torch.from_numpy(np.array(action)).type(self.dtype)
        reward_estimates = self.reward_network(
            rlt.StateAction(
                state=rlt.FeatureVector(float_features=state),
                action=rlt.FeatureVector(float_features=action),
            )
        )
    self.reward_network.train()
    return reward_estimates.q_value.cpu().data.numpy()
def extract(self, ws, input_record, extract_record):
    def fetch(b):
        data = ws.fetch_blob(str(b()))
        return torch.tensor(data)

    def fetch_action(b):
        if self.sorted_action_features is None:
            return fetch(b)
        else:
            return mt.FeatureVector(float_features=fetch(b))

    state = mt.FeatureVector(float_features=fetch(extract_record.state))
    next_state = mt.FeatureVector(float_features=fetch(extract_record.next_state))
    action = fetch_action(extract_record.action)
    reward = fetch(input_record.reward)
    # is_terminal should be filled by preprocessor
    if self.max_q_learning:
        possible_next_actions = mt.PossibleActions(
            lengths=fetch(extract_record.possible_next_actions["lengths"]),
            actions=fetch_action(extract_record.possible_next_actions["values"]),
        )
        training_input = mt.MaxQLearningInput(
            state=state,
            action=action,
            next_state=next_state,
            possible_next_actions=possible_next_actions,
            reward=reward,
            is_terminal=None,
        )
    else:
        next_action = fetch_action(extract_record.next_action)
        training_input = mt.SARSAInput(
            state=state,
            action=action,
            next_state=next_state,
            next_action=next_action,
            reward=reward,
            is_terminal=None,
        )
    # TODO: stuff other fields in here
    extras = None
    return mt.TrainingBatch(training_input=training_input, extras=extras)
def internal_prediction(self, states, noisy=False) -> np.ndarray:
    """ Returns list of actions output from actor network
    :param states: states as list of states to produce actions for
    """
    self.actor.eval()
    # TODO: Handle states being sequences
    state_examples = rlt.FeatureVector(
        float_features=torch.from_numpy(np.array(states)).type(self.dtype)
    )
    action = self.actor(rlt.StateAction(state=state_examples, action=None)).action
    self.actor.train()

    action = rescale_torch_tensor(
        action,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )
    action = action.cpu().data.numpy()

    if noisy:
        action = [x + (self.noise.get_noise()) for x in action]

    return np.array(action, dtype=np.float32)
def input_prototype(self):
    return rlt.StateInput(
        state=rlt.FeatureVector(
            float_features=torch.randn(1, self.state_dim),
            sequence_features=SequenceFeatures.prototype(),
        )
    )
def preprocess(self, batch) -> rlt.RawTrainingBatch:
    state_features_dense, state_features_dense_presence = self.sparse_to_dense_processor(
        batch["state_features"]
    )
    next_state_features_dense, next_state_features_dense_presence = self.sparse_to_dense_processor(
        batch["next_state_features"]
    )

    mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
    sequence_numbers = torch.tensor(
        batch["sequence_number"], dtype=torch.int32
    ).reshape(-1, 1)
    rewards = torch.tensor(batch["reward"], dtype=torch.float32).reshape(-1, 1)
    time_diffs = torch.tensor(batch["time_diff"], dtype=torch.int32).reshape(-1, 1)
    if "action_probability" in batch:
        propensities = torch.tensor(
            batch["action_probability"], dtype=torch.float32
        ).reshape(-1, 1)
    else:
        propensities = torch.ones(rewards.shape, dtype=torch.float32)

    return rlt.RawTrainingBatch(
        training_input=rlt.RawBaseInput(  # type: ignore
            state=rlt.FeatureVector(
                float_features=rlt.ValuePresence(
                    value=state_features_dense,
                    presence=state_features_dense_presence,
                )
            ),
            next_state=rlt.FeatureVector(
                float_features=rlt.ValuePresence(
                    value=next_state_features_dense,
                    presence=next_state_features_dense_presence,
                )
            ),
            reward=rewards,
            time_diff=time_diffs,
            step=None,
            not_terminal=None,
        ),
        extras=rlt.ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        ),
    )
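# Hedged sketch of the sparse-to-dense conversion that
# `self.sparse_to_dense_processor` is assumed to perform in preprocess():
# each row's {feature_id: value} dict becomes a dense row in a fixed feature
# order, plus a parallel presence mask marking which features were present.
# The helper name and feature ordering here are illustrative only.
import torch


def sparse_to_dense_sketch(rows, sorted_feature_ids):
    values = torch.zeros(len(rows), len(sorted_feature_ids))
    presence = torch.zeros(len(rows), len(sorted_feature_ids))
    for i, row in enumerate(rows):
        for j, fid in enumerate(sorted_feature_ids):
            if fid in row:
                values[i, j] = row[fid]
                presence[i, j] = 1.0
    return values, presence


dense, present = sparse_to_dense_sketch([{0: 1.0, 5: 2.0}], [0, 1, 5])
assert dense.tolist() == [[1.0, 0.0, 2.0]]
assert present.tolist() == [[1.0, 0.0, 1.0]]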
def _test_predictor_export(self, modular=False):
    """Verify that q-values before model export equal q-values after
    model export. Meant to catch issues with export logic."""
    environment = Gridworld()
    samples = Samples(
        mdp_ids=["0"],
        sequence_numbers=[0],
        states=[{0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 15: 1.0, 24: 1.0}],
        actions=["D"],
        action_probabilities=[0.5],
        rewards=[0],
        possible_actions=[["R", "D"]],
        next_states=[{5: 1.0}],
        next_actions=["U"],
        terminals=[False],
        possible_next_actions=[["R", "U", "D"]],
    )
    tdps = environment.preprocess_samples(samples, 1)

    if modular:
        trainer, exporter = self.get_modular_sarsa_trainer_exporter(
            environment, {}, False
        )
        input = rlt.StateInput(
            state=rlt.FeatureVector(float_features=tdps[0].states)
        )
    else:
        trainer, exporter = self.get_sarsa_trainer_exporter(environment, {}, False)
        input = tdps[0].states

    if modular:
        pre_export_q_values = trainer.q_network(input).q_values.detach().numpy()
    else:
        pre_export_q_values = trainer.q_network(input).detach().numpy()

    predictor = exporter.export()
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_path = os.path.join(tmpdirname, "model")
        predictor.save(tmp_path, "minidb")
        new_predictor = DQNPredictor.load(tmp_path, "minidb")

    post_export_q_values = new_predictor.predict([samples.states[0]])

    for i, action in enumerate(environment.ACTIONS):
        self.assertAlmostEqual(
            pre_export_q_values[0][i], post_export_q_values[0][action], places=4
        )
def internal_reward_estimation(self, input):
    """ Only used by Gym """
    self.reward_network.eval()
    reward_estimates = self.reward_network(
        rlt.StateInput(rlt.FeatureVector(float_features=input))
    )
    self.reward_network.train()
    return reward_estimates.q_values.cpu()
def internal_reward_estimation(self, input):
    """ Only used by Gym """
    self.reward_network.eval()
    with torch.no_grad():
        input = torch.from_numpy(np.array(input)).type(self.dtype)
        reward_estimates = self.reward_network(
            rlt.StateInput(rlt.FeatureVector(float_features=input))
        )
    self.reward_network.train()
    return reward_estimates.q_values.cpu().data.numpy()
def as_discrete_sarsa_training_batch(self):
    return rlt.TrainingBatch(
        training_input=rlt.SARSAInput(
            state=rlt.FeatureVector(float_features=self.states),
            reward=self.rewards,
            time_diff=self.time_diffs,
            action=self.actions,
            next_action=self.next_actions,
            not_terminal=self.not_terminal,
            next_state=rlt.FeatureVector(float_features=self.next_states),
            step=self.step,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def input_prototype(self):
    return rlt.PreprocessedState(
        state=rlt.FeatureVector(
            float_features=torch.randn(1, self.state_dim),
            id_list_features={
                "page_id": (
                    torch.zeros(1, dtype=torch.long),
                    torch.ones(1, dtype=torch.long),
                )
            },
        )
    )
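# The ("page_id": (offsets, ids)) pair above looks like an EmbeddingBag-style
# id-list encoding: a 1-D tensor of per-example offsets into a flat 1-D tensor
# of ids. This is a hedged illustration of how such a pair could be consumed;
# the embedding sizes and ids below are made up.
import torch

offsets = torch.tensor([0, 2], dtype=torch.long)    # example 0 owns ids[0:2], example 1 owns ids[2:]
ids = torch.tensor([10, 11, 12], dtype=torch.long)  # flat list of page ids
bag = torch.nn.EmbeddingBag(num_embeddings=100, embedding_dim=4)
pooled = bag(ids, offsets)                          # one pooled embedding per example
assert pooled.shape == (2, 4)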
def as_parametric_maxq_training_batch(self):
    state_dim = self.states.shape[1]
    return rlt.TrainingBatch(
        training_input=rlt.ParametricDqnInput(
            state=rlt.FeatureVector(float_features=self.states),
            action=rlt.FeatureVector(float_features=self.actions),
            next_state=rlt.FeatureVector(float_features=self.next_states),
            next_action=rlt.FeatureVector(float_features=self.next_actions),
            tiled_next_state=rlt.FeatureVector(
                float_features=self.possible_next_actions_state_concat[:, :state_dim]
            ),
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=rlt.FeatureVector(
                float_features=self.possible_next_actions_state_concat[:, state_dim:]
            ),
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
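# Illustrative split of possible_next_actions_state_concat as used above: the
# matrix is assumed to be tiled next states concatenated column-wise with the
# candidate next actions, so the first state_dim columns are states and the
# remaining columns are actions. Dimensions below are made up.
import torch

state_dim, action_dim, num_rows = 3, 2, 4
tiled_next_state = torch.zeros(num_rows, state_dim)
candidate_actions = torch.ones(num_rows, action_dim)
concat = torch.cat((tiled_next_state, candidate_actions), dim=1)
assert torch.equal(concat[:, :state_dim], tiled_next_state)
assert torch.equal(concat[:, state_dim:], candidate_actions)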
def test_predictor_torch_export(self):
    """Verify that q-values before model export equal q-values after
    model export. Meant to catch issues with export logic."""
    environment = Gridworld()
    samples = Samples(
        mdp_ids=["0"],
        sequence_numbers=[0],
        sequence_number_ordinals=[1],
        states=[{0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 15: 1.0, 24: 1.0}],
        actions=["D"],
        action_probabilities=[0.5],
        rewards=[0],
        possible_actions=[["R", "D"]],
        next_states=[{5: 1.0}],
        next_actions=["U"],
        terminals=[False],
        possible_next_actions=[["R", "U", "D"]],
    )
    tdps = environment.preprocess_samples(samples, 1)
    assert len(tdps) == 1, "Invalid number of data pages"

    trainer, exporter = self.get_modular_sarsa_trainer_exporter(
        environment, {}, False
    )
    input = rlt.StateInput(state=rlt.FeatureVector(float_features=tdps[0].states))
    pre_export_q_values = trainer.q_network(input).q_values.detach().numpy()

    preprocessor = Preprocessor(environment.normalization, False)
    serving_module = DiscreteDqnPredictorWrapper(
        state_preprocessor=preprocessor,
        value_network=trainer.q_network.cpu_model().fc,
        action_names=environment.ACTIONS,
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        buf = export_module_to_buffer(serving_module)
        tmp_path = os.path.join(tmpdirname, "model")
        with open(tmp_path, "wb") as f:
            f.write(buf.getvalue())
            f.close()
        predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

    post_export_q_values = predictor.predict([samples.states[0]])

    for i, action in enumerate(environment.ACTIONS):
        self.assertAlmostEqual(
            float(pre_export_q_values[0][i]),
            float(post_export_q_values[0][action]),
            places=4,
        )
def _maybe_scale_action_in_train(self, action):
    if (
        self.min_action_range_tensor_training is not None
        and self.max_action_range_tensor_training is not None
        and self.min_action_range_tensor_serving is not None
        and self.max_action_range_tensor_serving is not None
    ):
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )
    return action
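# Illustrative sketch of the affine rescaling that `rescale_torch_tensor` is
# assumed to perform in the snippets above: map values element-wise from the
# [prev_min, prev_max] range into [new_min, new_max]. This is a guess at the
# behavior, not the repository's actual implementation.
import torch


def rescale_sketch(t, new_min, new_max, prev_min, prev_max):
    # normalize to [0, 1] within the previous range, then scale to the new range
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return (t - prev_min) / prev_range * new_range + new_min


serving = torch.tensor([[-2.0, 0.0, 2.0]])
training = rescale_sketch(
    serving,
    new_min=torch.tensor(-1.0),
    new_max=torch.tensor(1.0),
    prev_min=torch.tensor(-2.0),
    prev_max=torch.tensor(2.0),
)
assert torch.allclose(training, torch.tensor([[-1.0, 0.0, 1.0]]))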
def internal_prediction(self, input):
    """ Only used by Gym """
    self.q_network.eval()
    q_values = self.q_network(
        rlt.StateInput(rlt.FeatureVector(float_features=input))
    )
    q_values = q_values.q_values.cpu()
    self.q_network.train()

    if self.bcq:
        action_preds = torch.tensor(self.bcq_imitator(input.cpu()))
        action_preds /= torch.max(action_preds, dim=1)[0]
        action_off_policy = (action_preds < self.bcq_drop_threshold).float()
        action_off_policy *= self.ACTION_NOT_POSSIBLE_VAL
        q_values += action_off_policy

    return q_values
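# Illustrative numbers for the batch-constrained (BCQ) masking above: actions
# whose imitator probability, relative to the most likely action, falls below
# the drop threshold get a large negative constant added to their q-values so
# they are never selected. Values below are made up, for a batch of one.
import torch

ACTION_NOT_POSSIBLE_VAL = -1e9
bcq_drop_threshold = 0.1
q_values = torch.tensor([[1.0, 2.0, 3.0]])
action_preds = torch.tensor([[0.7, 0.29, 0.01]])
action_preds = action_preds / torch.max(action_preds, dim=1)[0]
mask = (action_preds < bcq_drop_threshold).float() * ACTION_NOT_POSSIBLE_VAL
# the third action is masked out, so the second action wins despite a lower Q
assert torch.argmax(q_values + mask).item() == 1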
def internal_prediction(self, states):
    """ Returns list of actions output from actor network
    :param states: states as list of states to produce actions for
    """
    self.actor_network.eval()
    actions = self.actor_network(
        rlt.StateInput(rlt.FeatureVector(float_features=states))
    )

    # clamp actions to make sure actions are in the range
    clamped_actions = torch.max(
        torch.min(actions.action, self.max_action_range_tensor_training),
        self.min_action_range_tensor_training,
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )

    self.actor_network.train()
    return rescaled_actions
def internal_prediction(self, states, test=False):
    """ Returns list of actions output from actor network
    :param states: states as list of states to produce actions for
    """
    self.actor_network.eval()
    with torch.no_grad():
        state_examples = torch.from_numpy(np.array(states)).type(self.dtype)
        actions = self.actor_network(
            rlt.StateInput(rlt.FeatureVector(float_features=state_examples))
        ).action

    if not test:
        if self.minibatch < self.initial_exploration_ts:
            actions = (
                torch.rand_like(actions)
                * (
                    self.max_action_range_tensor_training
                    - self.min_action_range_tensor_training
                )
                + self.min_action_range_tensor_training
            )
        else:
            actions += torch.randn_like(actions) * self.exploration_noise

    # clamp actions to make sure actions are in the range
    clamped_actions = torch.max(
        torch.min(actions, self.max_action_range_tensor_training),
        self.min_action_range_tensor_training,
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )

    self.actor_network.train()
    return rescaled_actions
def train(self, training_batch, evaluator=None) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    s = learning_input.state
    a = learning_input.action.float_features
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    current_state_action = rlt.StateAction(
        state=learning_input.state, action=learning_input.action
    )

    q1_value = self.q1_network(current_state_action).q_value
    min_q_value = q1_value

    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
        min_q_value = torch.min(q1_value, q2_value)

    # Use the minimum as target, ensure no gradient going through
    min_q_value = min_q_value.detach()

    #
    # First, optimize value network; minimizing MSE between
    # V(s) & Q(s, a) - log(pi(a|s))
    #

    state_value = self.value_network(s.float_features)  # .q_value

    with torch.no_grad():
        log_prob_a = self.actor_network.get_log_prob(s, a)
        target_value = min_q_value - self.entropy_temperature * log_prob_a

    value_loss = F.mse_loss(state_value, target_value)
    self.value_network_optimizer.zero_grad()
    value_loss.backward()
    self.value_network_optimizer.step()

    #
    # Second, optimize Q networks; minimizing MSE between
    # Q(s, a) & r + discount * V'(next_s)
    #

    with torch.no_grad():
        next_state_value = (
            self.value_network_target(learning_input.next_state.float_features)
            * not_done_mask
        )

        if self.minibatch < self.reward_burnin:
            target_q_value = reward
        else:
            target_q_value = reward + discount * next_state_value

    q1_loss = F.mse_loss(q1_value, target_q_value)
    self.q1_network_optimizer.zero_grad()
    q1_loss.backward()
    self.q1_network_optimizer.step()
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        self.q2_network_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_network_optimizer.step()

    #
    # Lastly, optimize the actor; minimizing KL-divergence between action propensity
    # & softmax of value. Due to reparameterization trick, it ends up being
    # log_prob(actor_action) - Q(s, actor_action)
    #

    actor_output = self.actor_network(rlt.StateInput(state=learning_input.state))

    state_actor_action = rlt.StateAction(
        state=s, action=rlt.FeatureVector(float_features=actor_output.action)
    )
    q1_actor_value = self.q1_network(state_actor_action).q_value
    min_q_actor_value = q1_actor_value
    if self.q2_network:
        q2_actor_value = self.q2_network(state_actor_action).q_value
        min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

    actor_loss = torch.mean(
        self.entropy_temperature * actor_output.log_prob - min_q_actor_value
    )
    self.actor_network_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_network_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.value_network, self.value_network_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.value_network, self.value_network_target, self.tau)

    if evaluator is not None:
        # FIXME
        self.evaluate(evaluator)
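# A small self-contained sketch of the two targets used in the SAC-style
# update above, with dummy tensors standing in for network outputs. The names
# are illustrative; only the formulas mirror the train() code:
#   value target: min(Q1, Q2)(s, a) - temperature * log pi(a|s)
#   Q target:     r + gamma * V_target(s') * not_done
import torch

temperature, gamma = 0.1, 0.99
q1, q2 = torch.tensor([[1.2]]), torch.tensor([[1.0]])
log_prob_a = torch.tensor([[-0.5]])
reward = torch.tensor([[0.3]])
next_state_value = torch.tensor([[0.8]])
not_done = torch.tensor([[1.0]])

target_value = torch.min(q1, q2) - temperature * log_prob_a
target_q_value = reward + gamma * next_state_value * not_done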
def input_prototype(self):
    return rlt.StateInput(
        state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim))
    )
def train(self, training_batch: rlt.TrainingBatch) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
    action = rlt.FeatureVector(
        rescale_torch_tensor(
            learning_input.action.float_features,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
    )
    rewards = learning_input.reward
    next_state = learning_input.next_state
    time_diffs = learning_input.time_diff
    discount_tensor = torch.full_like(rewards, self.gamma)
    not_done_mask = learning_input.not_terminal

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic.forward(
        rlt.StateAction(state=state, action=action)
    ).q_value
    next_action = rlt.FeatureVector(
        float_features=self.actor_target(
            rlt.StateAction(state=next_state, action=None)
        ).action
    )

    q_s2_a2 = self.critic_target.forward(
        rlt.StateAction(state=next_state, action=next_action)
    ).q_value
    filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # compute loss and update the critic network
    critic_predictions = q_s1_a1
    loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
    loss_critic_for_eval = loss_critic.detach()
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
    actor_output = self.actor(rlt.StateAction(state=state, action=None))
    loss_actor = -(
        self.critic.forward(
            rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_output.action),
            )
        ).q_value.mean()
    )

    # Zero out both the actor and critic gradients because we need
    # to backprop through the critic to get to the actor
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    # Use the soft update rule to update both target networks
    self._soft_update(self.actor, self.actor_target, self.tau)
    self._soft_update(self.critic, self.critic_target, self.tau)

    self.loss_reporter.report(
        td_loss=float(loss_critic_for_eval),
        reward_loss=None,
        model_values_on_logged_actions=critic_predictions,
    )
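# Sketch of the Polyak averaging that `_soft_update(net, target_net, tau)` is
# assumed to implement in the trainers above (tau=1.0 copies the weights
# outright, as in the reward-burnin branch). This is illustrative, not the
# repository's actual helper.
import torch


def soft_update_sketch(net, target_net, tau):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


net = torch.nn.Linear(2, 1)
target = torch.nn.Linear(2, 1)
soft_update_sketch(net, target, tau=1.0)
assert torch.equal(net.weight, target.weight)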