Example #1
0
File: ppo.py — Project: ComMedX/deepchem-1
 def _build_model(self, model_dir):
     """Construct a KerasModel containing the policy and loss calculations.

     Builds Input layers matching the environment's state signature, wraps
     the policy network and PPO loss in a KerasModel, and forces creation of
     the training ops by feeding zero-filled example arrays.

     Parameters
     ----------
     model_dir: str
         directory in which the KerasModel stores its data

     Returns
     -------
     a KerasModel ready for training
     """
     state_shape = self._env.state_shape
     state_dtype = self._env.state_dtype
     # Normalize to list form so single- and multi-input environments are
     # handled by the same code path below.
     if not self._state_is_list:
         state_shape = [state_shape]
         state_dtype = [state_dtype]
     features = []
     for s, d in zip(state_shape, state_dtype):
         features.append(
             tf.keras.layers.Input(shape=list(s), dtype=tf.as_dtype(d)))
     policy_model = self._policy.create_model()
     output_names = self._policy.output_names
     # PPOLoss locates the probability and value heads by their index in
     # the policy's declared output list.
     loss = PPOLoss(self.value_weight, self.entropy_weight,
                    self.clipping_width, output_names.index('action_prob'),
                    output_names.index('value'))
     model = KerasModel(policy_model,
                        loss,
                        batch_size=self.max_rollout_length,
                        model_dir=model_dir,
                        optimizer=self._optimizer)  # was 'optimize=': not a KerasModel kwarg
     env = self._env
     # Zero-filled dummy batches, shaped like real rollout data, used only
     # to trigger graph/op construction before training begins.
     example_inputs = [
         np.zeros([model.batch_size] + list(shape), dtype)
         for shape, dtype in zip(state_shape, state_dtype)
     ]
     example_labels = [np.zeros((model.batch_size, env.n_actions))]
     # Three weight arrays — presumably one per loss term (policy, value,
     # entropy); TODO(review): confirm against PPOLoss's expected inputs.
     example_weights = [np.zeros(model.batch_size)] * 3
     model._create_training_ops(
         (example_inputs, example_labels, example_weights))
     return model
Example #2
0
 def _build_model(self, model_dir):
     """Construct a KerasModel containing the policy and loss calculations.

     Wraps the policy network and PPO loss in a KerasModel and eagerly
     builds it so it is ready for training/checkpointing.

     Parameters
     ----------
     model_dir: str
         directory in which the KerasModel stores its data

     Returns
     -------
     a built KerasModel
     """
     policy_model = self._policy.create_model()
     # Output indices were resolved elsewhere (self._action_prob_index /
     # self._value_index) rather than looked up here.
     loss = PPOLoss(self.value_weight, self.entropy_weight,
                    self.clipping_width, self._action_prob_index,
                    self._value_index)
     model = KerasModel(policy_model,
                        loss,
                        batch_size=self.max_rollout_length,
                        model_dir=model_dir,
                        optimizer=self._optimizer)  # was 'optimize=': not a KerasModel kwarg
     # Build variables/ops now instead of lazily on first fit call.
     model._ensure_built()
     return model
Example #3
0
    def _build_model(self, model_dir):
        """Construct a KerasModel containing the policy and loss calculations.

        Selects the continuous- or discrete-action A2C loss based on
        self.continuous, wraps it with the policy network in a KerasModel,
        and eagerly builds the model.

        Parameters
        ----------
        model_dir: str
            directory in which the KerasModel stores its data

        Returns
        -------
        a built KerasModel
        """
        policy_model = self._policy.create_model()
        if self.continuous:
            # Continuous actions: loss consumes mean and std output heads.
            loss = A2CLossContinuous(self.value_weight, self.entropy_weight,
                                     self._action_mean_index,
                                     self._action_std_index, self._value_index)
        else:
            # Discrete actions: loss consumes the action-probability head.
            loss = A2CLossDiscrete(self.value_weight, self.entropy_weight,
                                   self._action_prob_index, self._value_index)
        model = KerasModel(policy_model,
                           loss,
                           batch_size=self.max_rollout_length,
                           model_dir=model_dir,
                           optimizer=self._optimizer)  # was 'optimize=': not a KerasModel kwarg
        # Build variables/ops now instead of lazily on first fit call.
        model._ensure_built()

        return model