Example #1
    def sample_memories(self, batch_size, batch_first=False):
        """
        :param batch_size: number of samples to return
        :param batch_first: If True, the first dimension of data is batch_size.
            If False (default), the first dimension is SEQ_LEN. Therefore,
            state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example. By default,
            MDN-RNN consumes data with SEQ_LEN as the first dimension.
        """
        sample_indices = np.random.randint(self.memory_size, size=batch_size)
        # state/next state shape: batch_size x seq_len x state_dim
        # action shape: batch_size x seq_len x action_dim
        # reward/not_terminal shape: batch_size x seq_len
        state, action, next_state, reward, not_terminal = map(
            lambda x: torch.tensor(x, dtype=torch.float),
            zip(*self.deque_sample(sample_indices)),
        )

        if not batch_first:
            state, action, next_state, reward, not_terminal = transpose(
                state, action, next_state, reward, not_terminal
            )

        training_input = rlt.MemoryNetworkInput(
            state=rlt.FeatureVector(float_features=state),
            action=rlt.FeatureVector(float_features=action),
            next_state=next_state,
            reward=reward,
            not_terminal=not_terminal,
        )
        return rlt.TrainingBatch(training_input=training_input, extras=None)
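
Example #1 calls a transpose helper that is not shown here. A minimal sketch, assuming it simply swaps the first two (batch and sequence) dimensions of every tensor it is given, which is what the batch_first docstring implies:

import torch

# Assumed helper: swap the batch and sequence dimensions of each tensor.
def transpose(*tensors):
    return tuple(t.transpose(0, 1) for t in tensors)

# With batch_first=False, a batch_size x seq_len x state_dim tensor becomes
# seq_len x batch_size x state_dim.
state = torch.zeros(4, 10, 3)
(state_t,) = transpose(state)
assert state_t.shape == (10, 4, 3)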
Example #2
    def embed_state(self, state):
        """ Embed state after either reset() or step() """
        assert len(self.recent_states) == len(self.recent_actions)
        old_mdnrnn_mode = self.mdnrnn.mdnrnn.training
        self.mdnrnn.mdnrnn.eval()

        # Embed the state as the hidden layer's output
        # until the previous step + current state
        if len(self.recent_states) == 0:
            mdnrnn_state = np.zeros((1, self.raw_state_dim))
            mdnrnn_action = np.zeros((1, self.action_dim))
        else:
            mdnrnn_state = np.array(list(self.recent_states))
            mdnrnn_action = np.array(list(self.recent_actions))

        mdnrnn_state = torch.tensor(mdnrnn_state,
                                    dtype=torch.float).unsqueeze(1)
        mdnrnn_action = torch.tensor(mdnrnn_action,
                                     dtype=torch.float).unsqueeze(1)
        mdnrnn_input = rlt.StateAction(
            state=rlt.FeatureVector(float_features=mdnrnn_state),
            action=rlt.FeatureVector(float_features=mdnrnn_action),
        )
        mdnrnn_output = self.mdnrnn(mdnrnn_input)
        hidden_embed = (
            mdnrnn_output.all_steps_lstm_hidden[-1].squeeze().detach().numpy())
        state_embed = np.hstack((hidden_embed, state))
        self.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
        logger.debug(
            "Embed_state\nrecent states: {}\nrecent actions: {}\nstate_embed{}\n"
            .format(np.array(self.recent_states),
                    np.array(self.recent_actions), state_embed))
        return state_embed
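
Several of these examples temporarily put a network into eval mode and then restore whatever mode it was in before, as embed_state does with old_mdnrnn_mode. The pattern can be made explicit with a small context manager; this is only an illustration, not part of the library:

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def eval_mode(module: nn.Module):
    """Temporarily switch `module` to eval mode, restoring its previous mode on exit."""
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(was_training)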
Example #3
 def input_prototype(self):
     return rlt.StateAction(
         state=rlt.FeatureVector(
             float_features=torch.randn(1, self.state_dim)),
         action=rlt.FeatureVector(
             float_features=torch.randn(1, self.action_dim)),
     )
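
input_prototype() builds a single-example input of the right shape, presumably so the model can be exercised once (e.g. for shape checks or tracing). A generic sketch with a toy module and plain tensors; the real code wraps the tensors in rlt.StateAction / rlt.FeatureVector instead:

import torch
import torch.nn as nn

class ToyQNetwork(nn.Module):
    """Toy stand-in: a Q-network over concatenated state and action features."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Linear(state_dim + action_dim, 1)

    def forward(self, state, action):
        return self.fc(torch.cat([state, action], dim=1))

net = ToyQNetwork(state_dim=4, action_dim=2)
prototype = (torch.randn(1, 4), torch.randn(1, 2))  # mirrors input_prototype()
assert net(*prototype).shape == (1, 1)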
Example #4
    def extract(self, ws, input_record, extract_record):
        def fetch(b):
            data = ws.fetch_blob(str(b()))
            return torch.tensor(data)

        def fetch_action(b):
            if self.sorted_action_features is None:
                return fetch(b)
            else:
                return mt.FeatureVector(float_features=fetch(b))

        state = mt.FeatureVector(float_features=fetch(extract_record.state))
        action = fetch_action(extract_record.action)
        reward = fetch(input_record.reward).reshape(-1, 1)

        # is_terminal should be filled by preprocessor
        if self.max_q_learning:
            if self.sorted_action_features is not None:
                next_state = None
                tiled_next_state = mt.FeatureVector(
                    float_features=fetch(extract_record.tiled_next_state))
            else:
                next_state = mt.FeatureVector(
                    float_features=fetch(extract_record.next_state))
                tiled_next_state = None
            possible_next_actions = mt.PossibleActions(
                lengths=fetch(extract_record.possible_next_actions["lengths"]),
                actions=fetch_action(
                    extract_record.possible_next_actions["values"]),
            )

            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                tiled_next_state=tiled_next_state,
                possible_next_actions=possible_next_actions,
                reward=reward,
                not_terminal=(possible_next_actions.lengths >
                              0).float().reshape(-1, 1),
            )
        else:
            next_state = mt.FeatureVector(
                float_features=fetch(extract_record.next_state))
            next_action = fetch_action(extract_record.next_action)
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                # HACK: Need a better way to check this
                not_terminal=torch.ones_like(reward),
            )

        # TODO: stuff other fields in here
        extras = mt.ExtraData(action_probability=fetch(
            input_record.action_probability).reshape(-1, 1))

        return mt.TrainingBatch(training_input=training_input, extras=extras)
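
The not_terminal flag above is derived from possible_next_actions.lengths. A small sketch of the ragged layout this implies, where lengths holds the number of candidate actions per example and actions/values holds their concatenation (values are illustrative only):

import torch

lengths = torch.tensor([2, 0, 3])        # candidate actions per example
actions = torch.arange(5)                # 2 for example 0, none for 1, 3 for example 2
not_terminal = (lengths > 0).float().reshape(-1, 1)
offsets = torch.cumsum(lengths, dim=0) - lengths  # start of each example's slice in `actions`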
Example #5
 def as_discrete_maxq_training_batch(self):
     return rlt.TrainingBatch(
         training_input=rlt.MaxQLearningInput(
             state=rlt.FeatureVector(float_features=self.states),
             action=self.actions,
             next_state=rlt.FeatureVector(float_features=self.next_states),
             next_action=self.next_actions,
             tiled_next_state=None,
             possible_actions=None,
             possible_actions_mask=self.possible_actions_mask,
             possible_next_actions=None,
             possible_next_actions_mask=self.possible_next_actions_mask,
             reward=self.rewards,
             not_terminal=self.not_terminal,
             step=self.step,
             time_diff=self.time_diffs,
         ),
         extras=rlt.ExtraData(
             mdp_id=self.mdp_ids,
             sequence_number=self.sequence_numbers,
             action_probability=self.propensities,
             max_num_actions=self.max_num_actions,
             metrics=self.metrics,
         ),
     )
Example #6
 def input_prototype(self):
     if self.parametric_action:
         return rlt.StateAction(
             state=rlt.FeatureVector(
                 float_features=torch.randn(1, self.state_dim)),
             action=rlt.FeatureVector(
                 float_features=torch.randn(1, self.action_dim)),
         )
     else:
         return rlt.StateInput(state=rlt.FeatureVector(
             float_features=torch.randn(1, self.state_dim)))
Example #7
    def extract(self, ws, input_record, extract_record):
        def fetch(b):
            data = ws.fetch_blob(str(b()))
            return torch.tensor(data)

        state = mt.FeatureVector(float_features=fetch(extract_record.state))
        if self.sorted_action_features is None:
            action = None
        else:
            action = mt.FeatureVector(float_features=fetch(extract_record.action))
        return mt.StateAction(state=state, action=action)
Example #8
 def as_parametric_sarsa_training_batch(self):
     return rlt.TrainingBatch(
         training_input=rlt.SARSAInput(
             state=rlt.FeatureVector(float_features=self.states),
             action=rlt.FeatureVector(float_features=self.actions),
             next_state=rlt.FeatureVector(float_features=self.next_states),
             next_action=rlt.FeatureVector(float_features=self.next_actions),
             reward=self.rewards,
             not_terminal=self.not_terminals,
         ),
         extras=rlt.ExtraData(),
     )
Example #9
 def internal_reward_estimation(self, state, action):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     reward_estimates = self.reward_network(
         rlt.StateAction(
             state=rlt.FeatureVector(float_features=state),
             action=rlt.FeatureVector(float_features=action),
         ))
     self.reward_network.train()
     return reward_estimates.q_value.cpu()
Example #10
 def as_discrete_sarsa_training_batch(self):
     return rlt.TrainingBatch(
         training_input=rlt.SARSAInput(
             state=rlt.FeatureVector(float_features=self.states),
             action=self.actions,
             next_state=rlt.FeatureVector(float_features=self.next_states),
             next_action=self.next_actions,
             reward=self.rewards,
             not_terminal=self.not_terminal,
             step=self.step,
             time_diff=self.time_diffs,
         ),
         extras=rlt.ExtraData(),
     )
Example #11
 def as_policy_network_training_batch(self):
     return rlt.TrainingBatch(
         training_input=rlt.PolicyNetworkInput(
             state=rlt.FeatureVector(float_features=self.states),
             action=rlt.FeatureVector(float_features=self.actions),
             next_state=rlt.FeatureVector(float_features=self.next_states),
             next_action=rlt.FeatureVector(
                 float_features=self.next_actions),
             reward=self.rewards,
             not_terminal=self.not_terminal,
             step=self.step,
             time_diff=self.time_diffs,
         ),
         extras=rlt.ExtraData(),
     )
Example #12
 def internal_reward_estimation(self, state, action):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     with torch.no_grad():
         state = torch.from_numpy(np.array(state)).type(self.dtype)
         action = torch.from_numpy(np.array(action)).type(self.dtype)
         reward_estimates = self.reward_network(
             rlt.StateAction(
                 state=rlt.FeatureVector(float_features=state),
                 action=rlt.FeatureVector(float_features=action),
             ))
     self.reward_network.train()
     return reward_estimates.q_value.cpu().data.numpy()
Example #13
    def extract(self, ws, input_record, extract_record):
        def fetch(b):
            data = ws.fetch_blob(str(b()))
            return torch.tensor(data)

        def fetch_action(b):
            if self.sorted_action_features is None:
                return fetch(b)
            else:
                return mt.FeatureVector(float_features=fetch(b))

        state = mt.FeatureVector(float_features=fetch(extract_record.state))
        next_state = mt.FeatureVector(
            float_features=fetch(extract_record.next_state))
        action = fetch_action(extract_record.action)
        reward = fetch(input_record.reward)

        # is_terminal should be filled by preprocessor
        if self.max_q_learning:
            possible_next_actions = mt.PossibleActions(
                lengths=fetch(extract_record.possible_next_actions["lengths"]),
                actions=fetch_action(
                    extract_record.possible_next_actions["values"]),
            )

            training_input = mt.MaxQLearningInput(
                state=state,
                action=action,
                next_state=next_state,
                possible_next_actions=possible_next_actions,
                reward=reward,
                is_terminal=None,
            )
        else:
            next_action = fetch_action(extract_record.next_action)
            training_input = mt.SARSAInput(
                state=state,
                action=action,
                next_state=next_state,
                next_action=next_action,
                reward=reward,
                is_terminal=None,
            )

        # TODO: stuff other fields in here
        extras = None

        return mt.TrainingBatch(training_input=training_input, extras=extras)
Example #14
    def internal_prediction(self, states, noisy=False) -> np.ndarray:
        """ Returns list of actions output from actor network
        :param states: list of states to produce actions for
        """
        self.actor.eval()
        # TODO: Handle states being sequences
        state_examples = rlt.FeatureVector(
            float_features=torch.from_numpy(np.array(states)).type(self.dtype)
        )
        action = self.actor(rlt.StateAction(state=state_examples, action=None)).action

        self.actor.train()

        action = rescale_torch_tensor(
            action,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        action = action.cpu().data.numpy()
        if noisy:
            action = [x + (self.noise.get_noise()) for x in action]

        return np.array(action, dtype=np.float32)
Example #15
 def input_prototype(self):
     return rlt.StateInput(
         state=rlt.FeatureVector(
             float_features=torch.randn(1, self.state_dim),
             sequence_features=SequenceFeatures.prototype(),
         )
     )
Example #16
    def preprocess(self, batch) -> rlt.RawTrainingBatch:
        state_features_dense, state_features_dense_presence = self.sparse_to_dense_processor(
            batch["state_features"]
        )
        next_state_features_dense, next_state_features_dense_presence = self.sparse_to_dense_processor(
            batch["next_state_features"]
        )

        mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
        sequence_numbers = torch.tensor(
            batch["sequence_number"], dtype=torch.int32
        ).reshape(-1, 1)
        rewards = torch.tensor(batch["reward"], dtype=torch.float32).reshape(-1, 1)
        time_diffs = torch.tensor(batch["time_diff"], dtype=torch.int32).reshape(-1, 1)
        if "action_probability" in batch:
            propensities = torch.tensor(
                batch["action_probability"], dtype=torch.float32
            ).reshape(-1, 1)
        else:
            propensities = torch.ones(rewards.shape, dtype=torch.float32)

        return rlt.RawTrainingBatch(
            training_input=rlt.RawBaseInput(  # type: ignore
                state=rlt.FeatureVector(
                    float_features=rlt.ValuePresence(
                        value=state_features_dense,
                        presence=state_features_dense_presence,
                    )
                ),
                next_state=rlt.FeatureVector(
                    float_features=rlt.ValuePresence(
                        value=next_state_features_dense,
                        presence=next_state_features_dense_presence,
                    )
                ),
                reward=rewards,
                time_diff=time_diffs,
                step=None,
                not_terminal=None,
            ),
            extras=rlt.ExtraData(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                action_probability=propensities,
            ),
        )
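
The preprocess method leans on self.sparse_to_dense_processor, which is not shown in any of these examples. A hedged sketch of what such a step could look like, assuming each row of batch["state_features"] is a {feature_id: value} dict and the output is a dense value matrix plus a 0/1 presence mask:

import torch

def sparse_to_dense(rows, sorted_feature_ids):
    """Toy stand-in for the sparse-to-dense processor (assumed semantics)."""
    values = torch.zeros(len(rows), len(sorted_feature_ids))
    presence = torch.zeros_like(values)
    index = {fid: i for i, fid in enumerate(sorted_feature_ids)}
    for r, row in enumerate(rows):
        for fid, val in row.items():
            if fid in index:
                values[r, index[fid]] = float(val)
                presence[r, index[fid]] = 1.0
    return values, presence

values, presence = sparse_to_dense([{0: 1.0, 2: 0.5}, {1: 3.0}], [0, 1, 2])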
Example #17
    def _test_predictor_export(self, modular=False):
        """Verify that q-values before model export equal q-values after
        model export. Meant to catch issues with export logic."""
        environment = Gridworld()
        samples = Samples(
            mdp_ids=["0"],
            sequence_numbers=[0],
            states=[{
                0: 1.0,
                1: 1.0,
                2: 1.0,
                3: 1.0,
                4: 1.0,
                5: 1.0,
                15: 1.0,
                24: 1.0
            }],
            actions=["D"],
            action_probabilities=[0.5],
            rewards=[0],
            possible_actions=[["R", "D"]],
            next_states=[{
                5: 1.0
            }],
            next_actions=["U"],
            terminals=[False],
            possible_next_actions=[["R", "U", "D"]],
        )
        tdps = environment.preprocess_samples(samples, 1)

        if modular:
            trainer, exporter = self.get_modular_sarsa_trainer_exporter(
                environment, {}, False)
            input = rlt.StateInput(state=rlt.FeatureVector(
                float_features=tdps[0].states))
        else:
            trainer, exporter = self.get_sarsa_trainer_exporter(
                environment, {}, False)
            input = tdps[0].states

        if modular:
            pre_export_q_values = trainer.q_network(
                input).q_values.detach().numpy()
        else:
            pre_export_q_values = trainer.q_network(input).detach().numpy()

        predictor = exporter.export()
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_path = os.path.join(tmpdirname, "model")
            predictor.save(tmp_path, "minidb")
            new_predictor = DQNPredictor.load(tmp_path, "minidb")

        post_export_q_values = new_predictor.predict([samples.states[0]])

        for i, action in enumerate(environment.ACTIONS):
            self.assertAlmostEqual(pre_export_q_values[0][i],
                                   post_export_q_values[0][action],
                                   places=4)
Example #18
 def internal_reward_estimation(self, input):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     reward_estimates = self.reward_network(
         rlt.StateInput(rlt.FeatureVector(float_features=input)))
     self.reward_network.train()
     return reward_estimates.q_values.cpu()
Example #19
 def internal_reward_estimation(self, input):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     with torch.no_grad():
         input = torch.from_numpy(np.array(input)).type(self.dtype)
         reward_estimates = self.reward_network(
             rlt.StateInput(rlt.FeatureVector(float_features=input)))
     self.reward_network.train()
     return reward_estimates.q_values.cpu().data.numpy()
Example #20
 def as_discrete_sarsa_training_batch(self):
     return rlt.TrainingBatch(
         training_input=rlt.SARSAInput(
             state=rlt.FeatureVector(float_features=self.states),
             reward=self.rewards,
             time_diff=self.time_diffs,
             action=self.actions,
             next_action=self.next_actions,
             not_terminal=self.not_terminal,
             next_state=rlt.FeatureVector(float_features=self.next_states),
             step=self.step,
         ),
         extras=rlt.ExtraData(
             mdp_id=self.mdp_ids,
             sequence_number=self.sequence_numbers,
             action_probability=self.propensities,
             max_num_actions=self.max_num_actions,
             metrics=self.metrics,
         ),
     )
Example #21
 def input_prototype(self):
     return rlt.PreprocessedState(
         state=rlt.FeatureVector(
             float_features=torch.randn(1, self.state_dim),
             id_list_features={
                 "page_id": (
                     torch.zeros(1, dtype=torch.long),
                     torch.ones(1, dtype=torch.long),
                 )
             },
         )
     )
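
The ("page_id": (tensor, tensor)) pair in this prototype looks like an offsets/ids pair in the style of nn.EmbeddingBag; that reading is an assumption, since the example itself does not show how the pair is consumed:

import torch
import torch.nn as nn

offsets = torch.zeros(1, dtype=torch.long)  # one bag, starting at position 0
ids = torch.ones(1, dtype=torch.long)       # a single id (1) in that bag
embedding_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)
pooled = embedding_bag(ids, offsets)        # shape: (1, 4)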
Example #22
 def as_parametric_maxq_training_batch(self):
     state_dim = self.states.shape[1]
     return rlt.TrainingBatch(
         training_input=rlt.ParametricDqnInput(
             state=rlt.FeatureVector(float_features=self.states),
             action=rlt.FeatureVector(float_features=self.actions),
             next_state=rlt.FeatureVector(float_features=self.next_states),
             next_action=rlt.FeatureVector(
                 float_features=self.next_actions),
             tiled_next_state=rlt.FeatureVector(
                 float_features=self.possible_next_actions_state_concat[:, :state_dim]),
             possible_actions=None,
             possible_actions_mask=self.possible_actions_mask,
             possible_next_actions=rlt.FeatureVector(
                 float_features=self.possible_next_actions_state_concat[:, state_dim:]),
             possible_next_actions_mask=self.possible_next_actions_mask,
             reward=self.rewards,
             not_terminal=self.not_terminal,
             step=self.step,
             time_diff=self.time_diffs,
         ),
         extras=rlt.ExtraData(),
     )
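
The slicing above assumes each row of possible_next_actions_state_concat is the tiled next-state features followed by the candidate-action features; a tiny illustration of that layout:

import torch

state_dim, action_dim = 3, 2
concat = torch.randn(5, state_dim + action_dim)   # [next-state features | action features]
tiled_next_state = concat[:, :state_dim]
possible_next_actions = concat[:, state_dim:]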
Example #23
    def test_predictor_torch_export(self):
        """Verify that q-values before model export equal q-values after
        model export. Meant to catch issues with export logic."""
        environment = Gridworld()
        samples = Samples(
            mdp_ids=["0"],
            sequence_numbers=[0],
            sequence_number_ordinals=[1],
            states=[{0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 15: 1.0, 24: 1.0}],
            actions=["D"],
            action_probabilities=[0.5],
            rewards=[0],
            possible_actions=[["R", "D"]],
            next_states=[{5: 1.0}],
            next_actions=["U"],
            terminals=[False],
            possible_next_actions=[["R", "U", "D"]],
        )
        tdps = environment.preprocess_samples(samples, 1)
        assert len(tdps) == 1, "Invalid number of data pages"

        trainer, exporter = self.get_modular_sarsa_trainer_exporter(
            environment, {}, False
        )
        input = rlt.StateInput(state=rlt.FeatureVector(float_features=tdps[0].states))

        pre_export_q_values = trainer.q_network(input).q_values.detach().numpy()

        preprocessor = Preprocessor(environment.normalization, False)
        serving_module = DiscreteDqnPredictorWrapper(
            state_preprocessor=preprocessor,
            value_network=trainer.q_network.cpu_model().fc,
            action_names=environment.ACTIONS,
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            buf = export_module_to_buffer(serving_module)
            tmp_path = os.path.join(tmpdirname, "model")
            with open(tmp_path, "wb") as f:
                f.write(buf.getvalue())
            predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

        post_export_q_values = predictor.predict([samples.states[0]])

        for i, action in enumerate(environment.ACTIONS):
            self.assertAlmostEqual(
                float(pre_export_q_values[0][i]),
                float(post_export_q_values[0][action]),
                places=4,
            )
Example #24
 def _maybe_scale_action_in_train(self, action):
     if (self.min_action_range_tensor_training is not None
             and self.max_action_range_tensor_training is not None
             and self.min_action_range_tensor_serving is not None
             and self.max_action_range_tensor_serving is not None):
         action = rlt.FeatureVector(
             rescale_torch_tensor(
                 action.float_features,
                 new_min=self.min_action_range_tensor_training,
                 new_max=self.max_action_range_tensor_training,
                 prev_min=self.min_action_range_tensor_serving,
                 prev_max=self.max_action_range_tensor_serving,
             ))
     return action
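
rescale_torch_tensor itself is not shown in any of these examples. A minimal sketch, assuming it performs an element-wise linear map from the [prev_min, prev_max] range onto [new_min, new_max]:

import torch

def rescale_torch_tensor(tensor, new_min, new_max, prev_min, prev_max):
    """Assumed behavior: map `tensor` linearly from [prev_min, prev_max] to [new_min, new_max]."""
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return (tensor - prev_min) / prev_range * new_range + new_min

# e.g. a serving range of [-1, 1] mapped onto a training range of [0, 10]
x = torch.tensor([[-1.0], [0.0], [1.0]])
scaled = rescale_torch_tensor(
    x,
    new_min=torch.tensor(0.0), new_max=torch.tensor(10.0),
    prev_min=torch.tensor(-1.0), prev_max=torch.tensor(1.0),
)  # -> [[0.], [5.], [10.]]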
Example #25
    def internal_prediction(self, input):
        """
        Only used by Gym
        """
        self.q_network.eval()
        q_values = self.q_network(
            rlt.StateInput(rlt.FeatureVector(float_features=input)))
        q_values = q_values.q_values.cpu()
        self.q_network.train()

        if self.bcq:
            action_preds = torch.tensor(self.bcq_imitator(input.cpu()))
            action_preds /= torch.max(action_preds, dim=1, keepdim=True)[0]
            action_off_policy = (action_preds <
                                 self.bcq_drop_threshold).float()
            action_off_policy *= self.ACTION_NOT_POSSIBLE_VAL
            q_values += action_off_policy

        return q_values
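
The arithmetic behind the BCQ filter above, on toy numbers: actions whose imitator probability, normalized by the row maximum, falls below bcq_drop_threshold receive a large negative offset so they can never be selected. The threshold and the ACTION_NOT_POSSIBLE_VAL constant below are illustrative assumptions:

import torch

ACTION_NOT_POSSIBLE_VAL = -1e9          # assumed large negative constant
bcq_drop_threshold = 0.3                # illustrative threshold
action_preds = torch.tensor([[0.7, 0.2, 0.1]])
action_preds = action_preds / action_preds.max(dim=1, keepdim=True)[0]
mask = (action_preds < bcq_drop_threshold).float() * ACTION_NOT_POSSIBLE_VAL
q_values = torch.tensor([[1.0, 2.0, 3.0]]) + mask  # only the first action survives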
Example #26
    def internal_prediction(self, states):
        """ Returns list of actions output from actor network
        :param states: list of states to produce actions for
        """
        self.actor_network.eval()
        actions = self.actor_network(
            rlt.StateInput(rlt.FeatureVector(float_features=states)))
        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(actions.action, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        self.actor_network.train()
        return rescaled_actions
Example #27
    def internal_prediction(self, states, test=False):
        """ Returns list of actions output from actor network
        :param states: list of states to produce actions for
        """
        self.actor_network.eval()
        with torch.no_grad():
            state_examples = torch.from_numpy(np.array(states)).type(
                self.dtype)
            actions = self.actor_network(
                rlt.StateInput(
                    rlt.FeatureVector(float_features=state_examples))).action

        if not test:
            if self.minibatch < self.initial_exploration_ts:
                actions = (torch.rand_like(actions) *
                           (self.max_action_range_tensor_training -
                            self.min_action_range_tensor_training) +
                           self.min_action_range_tensor_training)
            else:
                actions += torch.randn_like(actions) * self.exploration_noise

        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(actions, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        self.actor_network.train()
        return rescaled_actions
Example #28
    def train(self, training_batch, evaluator=None) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        s = learning_input.state
        a = learning_input.action.float_features
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        current_state_action = rlt.StateAction(
            state=learning_input.state, action=learning_input.action
        )

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(s.float_features)  # .q_value

        with torch.no_grad():
            log_prob_a = self.actor_network.get_log_prob(s, a)
            target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask
            )

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=learning_input.state))

        state_actor_action = rlt.StateAction(
            state=s, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = torch.mean(
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        self.actor_network_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target, self.tau)

        if evaluator is not None:
            # FIXME
            self.evaluate(evaluator)
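
The value-network target computed in this trainer is min(Q1(s, a), Q2(s, a)) - entropy_temperature * log pi(a|s). A standalone restatement of that step on plain tensors:

import torch

q1_value = torch.tensor([[1.2], [0.7]])
q2_value = torch.tensor([[1.0], [0.9]])
log_prob_a = torch.tensor([[-0.3], [-1.1]])
entropy_temperature = 0.1

min_q_value = torch.min(q1_value, q2_value).detach()
target_value = min_q_value - entropy_temperature * log_prob_a  # -> [[1.03], [0.81]]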
Example #29
 def input_prototype(self):
     return rlt.StateInput(
         state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim))
     )
Example #30
    def train(self, training_batch: rlt.TrainingBatch) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state

        # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                learning_input.action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

        rewards = learning_input.reward
        next_state = learning_input.next_state
        time_diffs = learning_input.time_diff
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = learning_input.not_terminal

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic.forward(
            rlt.StateAction(state=state, action=action)
        ).q_value
        next_action = rlt.FeatureVector(
            float_features=self.actor_target(
                rlt.StateAction(state=next_state, action=None)
            ).action
        )

        q_s2_a2 = self.critic_target.forward(
            rlt.StateAction(state=next_state, action=next_action)
        ).q_value
        filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network
        critic_predictions = q_s1_a1
        loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
        loss_critic_for_eval = loss_critic.detach()
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
        actor_output = self.actor(rlt.StateAction(state=state, action=None))
        loss_actor = -(
            self.critic.forward(
                rlt.StateAction(
                    state=state,
                    action=rlt.FeatureVector(float_features=actor_output.action),
                )
            ).q_value.mean()
        )

        # Zero out both the actor and critic gradients because we need
        #   to backprop through the critic to get to the actor
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Use the soft update rule to update both target networks
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

        self.loss_reporter.report(
            td_loss=float(loss_critic_for_eval),
            reward_loss=None,
            model_values_on_logged_actions=critic_predictions,
        )
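
For reference, the critic target assembled above reduces to r + gamma**time_diff * not_done * Q_target(s2, a2) when use_seq_num_diff_as_time_diff is enabled; on plain tensors:

import torch

rewards = torch.tensor([[1.0], [0.0]])
q_s2_a2 = torch.tensor([[2.0], [3.0]])
not_done_mask = torch.tensor([[1.0], [0.0]])   # 0 marks terminal transitions
time_diffs = torch.tensor([[1.0], [2.0]])
gamma = 0.99

discount_tensor = torch.full_like(rewards, gamma).pow(time_diffs)
target_q_values = rewards + discount_tensor * (not_done_mask * q_s2_a2)  # -> [[2.98], [0.0]]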