def _build_model(self, model_dir):
    """Construct a KerasModel containing the policy and loss calculations.

    The policy network comes from ``self._policy.create_model()`` and is
    paired with a ``PPOLoss`` built from this object's value weight,
    entropy weight and clipping width. Training ops are then created from
    all-zero example batches shaped like the environment's states.

    Parameters
    ----------
    model_dir
        Directory passed through to ``KerasModel``.

    Returns
    -------
    KerasModel
        The model wrapping the policy network and the PPO loss, with its
        training ops already created.
    """
    env = self._env
    state_shape = env.state_shape
    state_dtype = env.state_dtype
    # Normalize single-input environments to the list form so the code
    # below can treat every environment as having a list of inputs.
    if not self._state_is_list:
        state_shape = [state_shape]
        state_dtype = [state_dtype]

    policy_model = self._policy.create_model()
    output_names = self._policy.output_names
    # PPOLoss locates the action probabilities and the value estimate by
    # their positions in the policy's outputs (ValueError if either is absent).
    loss = PPOLoss(self.value_weight, self.entropy_weight,
                   self.clipping_width, output_names.index('action_prob'),
                   output_names.index('value'))
    # The keyword is ``optimizer``. A misspelled ``optimize=`` does not name
    # KerasModel's optimizer parameter, so the configured optimizer could be
    # silently ignored (NOTE(review): confirm against KerasModel's signature).
    model = KerasModel(policy_model, loss,
                       batch_size=self.max_rollout_length,
                       model_dir=model_dir,
                       optimizer=self._optimizer)

    # All-zero batches with the right shapes/dtypes, used only so the
    # training ops can be created before any real rollout exists.
    batch_size = model.batch_size
    example_inputs = [
        np.zeros([batch_size] + list(shape), dtype)
        for shape, dtype in zip(state_shape, state_dtype)
    ]
    example_labels = [np.zeros((batch_size, env.n_actions))]
    # Three independent arrays rather than ``[arr] * 3``, which would alias a
    # single array; presumably one per PPOLoss weight input — TODO confirm.
    example_weights = [np.zeros(batch_size) for _ in range(3)]
    model._create_training_ops(
        (example_inputs, example_labels, example_weights))
    return model
def _build_model(self, model_dir):
    """Construct a KerasModel containing the policy and loss calculations.

    Parameters
    ----------
    model_dir
        Directory passed through to ``KerasModel``.

    Returns
    -------
    KerasModel
        The policy network from ``self._policy.create_model()`` wrapped
        with a ``PPOLoss``, after ``_ensure_built()`` has been called on it.
    """
    policy_model = self._policy.create_model()
    # PPOLoss is told which policy outputs hold the action probabilities
    # and the value estimate via the precomputed output indices.
    loss = PPOLoss(self.value_weight, self.entropy_weight,
                   self.clipping_width, self._action_prob_index,
                   self._value_index)
    # The keyword is ``optimizer``. A misspelled ``optimize=`` does not name
    # KerasModel's optimizer parameter, so the configured optimizer could be
    # silently ignored (NOTE(review): confirm against KerasModel's signature).
    model = KerasModel(policy_model, loss,
                       batch_size=self.max_rollout_length,
                       model_dir=model_dir,
                       optimizer=self._optimizer)
    # Build eagerly, presumably so variables/optimizer state exist before the
    # first training step — confirm against KerasModel._ensure_built.
    model._ensure_built()
    return model
def _build_model(self, model_dir):
    """Construct a KerasModel containing the policy and loss calculations.

    Chooses ``A2CLossContinuous`` when ``self.continuous`` is truthy (using
    the action mean/std output indices) and ``A2CLossDiscrete`` otherwise
    (using the action probability output index).

    Parameters
    ----------
    model_dir
        Directory passed through to ``KerasModel``.

    Returns
    -------
    KerasModel
        The policy network from ``self._policy.create_model()`` wrapped
        with the selected A2C loss, after ``_ensure_built()`` has been called.
    """
    policy_model = self._policy.create_model()
    if self.continuous:
        loss = A2CLossContinuous(self.value_weight, self.entropy_weight,
                                 self._action_mean_index,
                                 self._action_std_index, self._value_index)
    else:
        loss = A2CLossDiscrete(self.value_weight, self.entropy_weight,
                               self._action_prob_index, self._value_index)
    # The keyword is ``optimizer``. A misspelled ``optimize=`` does not name
    # KerasModel's optimizer parameter, so the configured optimizer could be
    # silently ignored (NOTE(review): confirm against KerasModel's signature).
    model = KerasModel(policy_model, loss,
                       batch_size=self.max_rollout_length,
                       model_dir=model_dir,
                       optimizer=self._optimizer)
    # Build eagerly, presumably so variables/optimizer state exist before the
    # first training step — confirm against KerasModel._ensure_built.
    model._ensure_built()
    return model